aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/accounts/__init__.py0
-rw-r--r--tests/accounts/models.py8
-rw-r--r--tests/accounts/serializers.py11
-rw-r--r--tests/description.py26
-rw-r--r--tests/extras/__init__.py0
-rw-r--r--tests/extras/bad_import.py1
-rw-r--r--tests/models.py170
-rw-r--r--tests/records/__init__.py0
-rw-r--r--tests/records/models.py6
-rw-r--r--tests/serializers.py8
-rw-r--r--tests/settings.py169
-rw-r--r--tests/test_authentication.py637
-rw-r--r--tests/test_breadcrumbs.py73
-rw-r--r--tests/test_decorators.py157
-rw-r--r--tests/test_description.py108
-rw-r--r--tests/test_fields.py984
-rw-r--r--tests/test_files.py95
-rw-r--r--tests/test_filters.py615
-rw-r--r--tests/test_genericrelations.py129
-rw-r--r--tests/test_generics.py609
-rw-r--r--tests/test_htmlrenderer.py118
-rw-r--r--tests/test_hyperlinkedserializers.py379
-rw-r--r--tests/test_multitable_inheritance.py67
-rw-r--r--tests/test_negotiation.py45
-rw-r--r--tests/test_nullable_fields.py30
-rw-r--r--tests/test_pagination.py517
-rw-r--r--tests/test_parsers.py115
-rw-r--r--tests/test_permissions.py291
-rw-r--r--tests/test_relations.py120
-rw-r--r--tests/test_relations_hyperlink.py524
-rw-r--r--tests/test_relations_nested.py328
-rw-r--r--tests/test_relations_pk.py551
-rw-r--r--tests/test_relations_slug.py257
-rw-r--r--tests/test_renderers.py651
-rw-r--r--tests/test_request.py347
-rw-r--r--tests/test_response.py278
-rw-r--r--tests/test_reverse.py27
-rw-r--r--tests/test_routers.py216
-rw-r--r--tests/test_serializer.py1857
-rw-r--r--tests/test_serializer_bulk_update.py278
-rw-r--r--tests/test_serializer_empty.py15
-rw-r--r--tests/test_serializer_import.py19
-rw-r--r--tests/test_serializer_nested.py347
-rw-r--r--tests/test_serializers.py28
-rw-r--r--tests/test_settings.py22
-rw-r--r--tests/test_status.py33
-rw-r--r--tests/test_templatetags.py51
-rw-r--r--tests/test_testing.py154
-rw-r--r--tests/test_throttling.py277
-rw-r--r--tests/test_urlpatterns.py76
-rw-r--r--tests/test_validation.py104
-rw-r--r--tests/test_views.py142
-rw-r--r--tests/test_write_only_fields.py42
-rw-r--r--tests/urls.py6
-rw-r--r--tests/users/__init__.py0
-rw-r--r--tests/users/models.py6
-rw-r--r--tests/users/serializers.py8
-rw-r--r--tests/views.py8
59 files changed, 12140 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/accounts/__init__.py b/tests/accounts/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/accounts/__init__.py
diff --git a/tests/accounts/models.py b/tests/accounts/models.py
new file mode 100644
index 00000000..3bf4a0c3
--- /dev/null
+++ b/tests/accounts/models.py
@@ -0,0 +1,8 @@
+from django.db import models
+
+from tests.users.models import User
+
+
+class Account(models.Model):
+ owner = models.ForeignKey(User, related_name='accounts_owned')
+ admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered')
diff --git a/tests/accounts/serializers.py b/tests/accounts/serializers.py
new file mode 100644
index 00000000..57a91b92
--- /dev/null
+++ b/tests/accounts/serializers.py
@@ -0,0 +1,11 @@
+from rest_framework import serializers
+
+from tests.accounts.models import Account
+from tests.users.serializers import UserSerializer
+
+
+class AccountSerializer(serializers.ModelSerializer):
+ admins = UserSerializer(many=True)
+
+ class Meta:
+ model = Account
diff --git a/tests/description.py b/tests/description.py
new file mode 100644
index 00000000..b46d7f54
--- /dev/null
+++ b/tests/description.py
@@ -0,0 +1,26 @@
+# -- coding: utf-8 --
+
+# Apparently there is a python 2.6 issue where docstrings of imported view classes
+# do not retain their encoding information even if a module has a proper
+# encoding declaration at the top of its source file. Therefore for tests
+# to catch unicode related errors, a mock view has to be declared in a separate
+# module.
+
+from rest_framework.views import APIView
+
+
+# test strings snatched from http://www.columbia.edu/~fdc/utf8/,
+# http://winrus.com/utf8-jap.htm and memory
+UTF8_TEST_DOCSTRING = (
+ 'zażółć gęślą jaźń'
+ 'Sîne klâwen durh die wolken sint geslagen'
+ 'Τη γλώσσα μου έδωσαν ελληνική'
+ 'யாமறிந்த மொழிகளிலே தமிழ்மொழி'
+ 'На берегу пустынных волн'
+ 'てすと'
+ 'アイウエオカキクケコサシスセソタチツテ'
+)
+
+
+class ViewWithNonASCIICharactersInDocstring(APIView):
+ __doc__ = UTF8_TEST_DOCSTRING
diff --git a/tests/extras/__init__.py b/tests/extras/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/extras/__init__.py
diff --git a/tests/extras/bad_import.py b/tests/extras/bad_import.py
new file mode 100644
index 00000000..68263d94
--- /dev/null
+++ b/tests/extras/bad_import.py
@@ -0,0 +1 @@
+raise ValueError
diff --git a/tests/models.py b/tests/models.py
new file mode 100644
index 00000000..32a726c0
--- /dev/null
+++ b/tests/models.py
@@ -0,0 +1,170 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import serializers
+
+
+def foobar():
+ return 'foobar'
+
+
+class CustomField(models.CharField):
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = 12
+ super(CustomField, self).__init__(*args, **kwargs)
+
+
+class RESTFrameworkModel(models.Model):
+ """
+ Base for test models that sets app_label, so they play nicely.
+ """
+ class Meta:
+ app_label = 'tests'
+ abstract = True
+
+
+class HasPositiveIntegerAsChoice(RESTFrameworkModel):
+ some_choices = ((1, 'A'), (2, 'B'), (3, 'C'))
+ some_integer = models.PositiveIntegerField(choices=some_choices)
+
+
+class Anchor(RESTFrameworkModel):
+ text = models.CharField(max_length=100, default='anchor')
+
+
+class BasicModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description."))
+
+
+class SlugBasedModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ slug = models.SlugField(max_length=32)
+
+
+class DefaultValueModel(RESTFrameworkModel):
+ text = models.CharField(default='foobar', max_length=100)
+ extra = models.CharField(blank=True, null=True, max_length=100)
+
+
+class CallableDefaultValueModel(RESTFrameworkModel):
+ text = models.CharField(default=foobar, max_length=100)
+
+
+class ManyToManyModel(RESTFrameworkModel):
+ rel = models.ManyToManyField(Anchor, help_text='Some help text.')
+
+
+class ReadOnlyManyToManyModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100, default='anchor')
+ rel = models.ManyToManyField(Anchor)
+
+
+# Model for regression test for #285
+
+class Comment(RESTFrameworkModel):
+ email = models.EmailField()
+ content = models.CharField(max_length=200)
+ created = models.DateTimeField(auto_now_add=True)
+
+
+class ActionItem(RESTFrameworkModel):
+ title = models.CharField(max_length=200)
+ started = models.NullBooleanField(default=False)
+ done = models.BooleanField(default=False)
+ info = CustomField(default='---', max_length=12)
+
+
+# Models for reverse relations
+class Person(RESTFrameworkModel):
+ name = models.CharField(max_length=10)
+ age = models.IntegerField(null=True, blank=True)
+
+ @property
+ def info(self):
+ return {
+ 'name': self.name,
+ 'age': self.age,
+ }
+
+
+class BlogPost(RESTFrameworkModel):
+ title = models.CharField(max_length=100)
+ writer = models.ForeignKey(Person, null=True, blank=True)
+
+ def get_first_comment(self):
+ return self.blogpostcomment_set.all()[0]
+
+
+class BlogPostComment(RESTFrameworkModel):
+ text = models.TextField()
+ blog_post = models.ForeignKey(BlogPost)
+
+
+class Album(RESTFrameworkModel):
+ title = models.CharField(max_length=100, unique=True)
+
+
+class Photo(RESTFrameworkModel):
+ description = models.TextField()
+ album = models.ForeignKey(Album)
+
+
+# Model for issue #324
+class BlankFieldModel(RESTFrameworkModel):
+ title = models.CharField(max_length=100, blank=True, null=False)
+
+
+# Model for issue #380
+class OptionalRelationModel(RESTFrameworkModel):
+ other = models.ForeignKey('OptionalRelationModel', blank=True, null=True)
+
+
+# Model for RegexField
+class Book(RESTFrameworkModel):
+ isbn = models.CharField(max_length=13)
+
+
+# Models for relations tests
+# ManyToMany
+class ManyToManyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+
+
+# ForeignKey
+class ForeignKeyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
+ related_name='nullable_sources')
+
+
+# OneToOne
+class OneToOneTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class NullableOneToOneSource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
+ related_name='nullable_source')
+
+
+# Serializer used to test BasicModel
+class BasicModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
diff --git a/tests/records/__init__.py b/tests/records/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/records/__init__.py
diff --git a/tests/records/models.py b/tests/records/models.py
new file mode 100644
index 00000000..76954807
--- /dev/null
+++ b/tests/records/models.py
@@ -0,0 +1,6 @@
+from django.db import models
+
+
+class Record(models.Model):
+ account = models.ForeignKey('accounts.Account', blank=True, null=True)
+ owner = models.ForeignKey('users.User', blank=True, null=True)
diff --git a/tests/serializers.py b/tests/serializers.py
new file mode 100644
index 00000000..f2f85b6e
--- /dev/null
+++ b/tests/serializers.py
@@ -0,0 +1,8 @@
+from rest_framework import serializers
+
+from tests.models import NullableForeignKeySource
+
+
+class NullableFKSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
diff --git a/tests/settings.py b/tests/settings.py
new file mode 100644
index 00000000..75f7c54b
--- /dev/null
+++ b/tests/settings.py
@@ -0,0 +1,169 @@
+# Django settings for testproject project.
+
+DEBUG = True
+TEMPLATE_DEBUG = DEBUG
+DEBUG_PROPAGATE_EXCEPTIONS = True
+
+ALLOWED_HOSTS = ['*']
+
+ADMINS = (
+ # ('Your Name', 'your_email@domain.com'),
+)
+
+MANAGERS = ADMINS
+
+DATABASES = {
+ 'default': {
+ 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'.
+ 'NAME': 'sqlite.db', # Or path to database file if using sqlite3.
+ 'USER': '', # Not used with sqlite3.
+ 'PASSWORD': '', # Not used with sqlite3.
+ 'HOST': '', # Set to empty string for localhost. Not used with sqlite3.
+ 'PORT': '', # Set to empty string for default. Not used with sqlite3.
+ }
+}
+
+CACHES = {
+ 'default': {
+ 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
+ }
+}
+
+# Local time zone for this installation. Choices can be found here:
+# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name
+# although not all choices may be available on all operating systems.
+# On Unix systems, a value of None will cause Django to use the same
+# timezone as the operating system.
+# If running in a Windows environment this must be set to the same as your
+# system time zone.
+TIME_ZONE = 'Europe/London'
+
+# Language code for this installation. All choices can be found here:
+# http://www.i18nguy.com/unicode/language-identifiers.html
+LANGUAGE_CODE = 'en-uk'
+
+SITE_ID = 1
+
+# If you set this to False, Django will make some optimizations so as not
+# to load the internationalization machinery.
+USE_I18N = True
+
+# If you set this to False, Django will not format dates, numbers and
+# calendars according to the current locale
+USE_L10N = True
+
+# Absolute filesystem path to the directory that will hold user-uploaded files.
+# Example: "/home/media/media.lawrence.com/"
+MEDIA_ROOT = ''
+
+# URL that handles the media served from MEDIA_ROOT. Make sure to use a
+# trailing slash if there is a path component (optional in other cases).
+# Examples: "http://media.lawrence.com", "http://example.com/media/"
+MEDIA_URL = ''
+
+# Make this unique, and don't share it with anybody.
+SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'
+
+# List of callables that know how to import templates from various sources.
+TEMPLATE_LOADERS = (
+ 'django.template.loaders.filesystem.Loader',
+ 'django.template.loaders.app_directories.Loader',
+# 'django.template.loaders.eggs.Loader',
+)
+
+MIDDLEWARE_CLASSES = (
+ 'django.middleware.common.CommonMiddleware',
+ 'django.contrib.sessions.middleware.SessionMiddleware',
+ 'django.middleware.csrf.CsrfViewMiddleware',
+ 'django.contrib.auth.middleware.AuthenticationMiddleware',
+ 'django.contrib.messages.middleware.MessageMiddleware',
+)
+
+ROOT_URLCONF = 'tests.urls'
+
+TEMPLATE_DIRS = (
+ # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates".
+ # Always use forward slashes, even on Windows.
+ # Don't forget to use absolute paths, not relative paths.
+)
+
+INSTALLED_APPS = (
+ 'django.contrib.auth',
+ 'django.contrib.contenttypes',
+ 'django.contrib.sessions',
+ 'django.contrib.sites',
+ 'django.contrib.messages',
+ # Uncomment the next line to enable the admin:
+ # 'django.contrib.admin',
+ # Uncomment the next line to enable admin documentation:
+ # 'django.contrib.admindocs',
+ 'rest_framework',
+ 'rest_framework.authtoken',
+ 'tests',
+ 'tests.accounts',
+ 'tests.records',
+ 'tests.users',
+)
+
+# OAuth is optional and won't work if there is no oauth_provider & oauth2
+try:
+ import oauth_provider
+ import oauth2
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'oauth_provider',
+ )
+
+try:
+ import provider
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'provider',
+ 'provider.oauth2',
+ )
+
+# guardian is optional
+try:
+ import guardian
+except ImportError:
+ pass
+else:
+ ANONYMOUS_USER_ID = -1
+ AUTHENTICATION_BACKENDS = (
+ 'django.contrib.auth.backends.ModelBackend', # default
+ 'guardian.backends.ObjectPermissionBackend',
+ )
+ INSTALLED_APPS += (
+ 'guardian',
+ )
+
+STATIC_URL = '/static/'
+
+PASSWORD_HASHERS = (
+ 'django.contrib.auth.hashers.SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.BCryptPasswordHasher',
+ 'django.contrib.auth.hashers.MD5PasswordHasher',
+ 'django.contrib.auth.hashers.CryptPasswordHasher',
+)
+
+AUTH_USER_MODEL = 'auth.User'
+
+import django
+
+if django.VERSION < (1, 3):
+ INSTALLED_APPS += ('staticfiles',)
+
+
+# If we're running on the Jenkins server we want to archive the coverage reports as XML.
+import os
+if os.environ.get('HUDSON_URL', None):
+ TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner'
+ TEST_OUTPUT_VERBOSE = True
+ TEST_OUTPUT_DESCRIPTIONS = True
+ TEST_OUTPUT_DIR = 'xmlrunner'
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
new file mode 100644
index 00000000..4ecfef44
--- /dev/null
+++ b/tests/test_authentication.py
@@ -0,0 +1,637 @@
+from __future__ import unicode_literals
+from django.contrib.auth.models import User
+from django.http import HttpResponse
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import HTTP_HEADER_ENCODING
+from rest_framework import exceptions
+from rest_framework import permissions
+from rest_framework import renderers
+from rest_framework.response import Response
+from rest_framework import status
+from rest_framework.authentication import (
+ BaseAuthentication,
+ TokenAuthentication,
+ BasicAuthentication,
+ SessionAuthentication,
+ OAuthAuthentication,
+ OAuth2Authentication
+)
+from rest_framework.authtoken.models import Token
+from rest_framework.compat import patterns, url, include
+from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
+from rest_framework.compat import oauth, oauth_provider
+from rest_framework.test import APIRequestFactory, APIClient
+from rest_framework.views import APIView
+import base64
+import time
+import datetime
+
+factory = APIRequestFactory()
+
+
+class MockView(APIView):
+ permission_classes = (permissions.IsAuthenticated,)
+
+ def get(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+ def post(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+ def put(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+
+urlpatterns = patterns('',
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
+ (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ permission_classes=[permissions.TokenHasReadWriteScope]))
+)
+
+if oauth2_provider is not None:
+ urlpatterns += patterns('',
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
+ url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ permission_classes=[permissions.TokenHasReadWriteScope])),
+ )
+
+
+class BasicAuthTests(TestCase):
+ """Basic authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def test_post_form_passing_basic_auth(self):
+ """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_json_passing_basic_auth(self):
+ """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_form_failing_basic_auth(self):
+ """Ensure POSTing form over basic auth without correct credentials fails"""
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_post_json_failing_basic_auth(self):
+ """Ensure POSTing json over basic auth without correct credentials fails"""
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
+
+
+class SessionAuthTests(TestCase):
+ """User session authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.non_csrf_client = APIClient(enforce_csrf_checks=False)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def tearDown(self):
+ self.csrf_client.logout()
+
+ def test_post_form_session_auth_failing_csrf(self):
+ """
+ Ensure POSTing form over session authentication without CSRF token fails.
+ """
+ self.csrf_client.login(username=self.username, password=self.password)
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_post_form_session_auth_passing(self):
+ """
+ Ensure POSTing form over session authentication with logged in user and CSRF token passes.
+ """
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_put_form_session_auth_passing(self):
+ """
+ Ensure PUTting form over session authentication with logged in user and CSRF token passes.
+ """
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_form_session_auth_failing(self):
+ """
+ Ensure POSTing form over session authentication without logged in user fails.
+ """
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+
+class TokenAuthTests(TestCase):
+ """Token authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.key = 'abcd1234'
+ self.token = Token.objects.create(key=self.key, user=self.user)
+
+ def test_post_form_passing_token_auth(self):
+ """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
+ auth = 'Token ' + self.key
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_json_passing_token_auth(self):
+ """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
+ auth = "Token " + self.key
+ response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_form_failing_token_auth(self):
+ """Ensure POSTing form over token auth without correct credentials fails"""
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_post_json_failing_token_auth(self):
+ """Ensure POSTing json over token auth without correct credentials fails"""
+ response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_token_has_auto_assigned_key_if_none_provided(self):
+ """Ensure creating a token with no key will auto-assign a key"""
+ self.token.delete()
+ token = Token.objects.create(user=self.user)
+ self.assertTrue(bool(token.key))
+
+ def test_token_login_json(self):
+ """Ensure token login view using JSON POST works."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': self.password}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['token'], self.key)
+
+ def test_token_login_json_bad_creds(self):
+ """Ensure token login view using JSON POST fails if bad credentials are used."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': "badpass"}, format='json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_json_missing_fields(self):
+ """Ensure token login view using JSON POST fails if missing fields."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username}, format='json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_form(self):
+ """Ensure token login view using form POST works."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['token'], self.key)
+
+
+class IncorrectCredentialsTests(TestCase):
+ def test_incorrect_credentials(self):
+ """
+ If a request contains bad authentication credentials, then
+ authentication should run and error, even if no permissions
+ are set on the view.
+ """
+ class IncorrectCredentialsAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('Bad credentials')
+
+ request = factory.get('/')
+ view = MockView.as_view(
+ authentication_classes=(IncorrectCredentialsAuth,),
+ permission_classes=()
+ )
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.data, {'detail': 'Bad credentials'})
+
+
+class OAuthTests(TestCase):
+ """OAuth 1.0a authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ # these imports are here because oauth is optional and hiding them in try..except block or compat
+ # could obscure problems if something breaks
+ from oauth_provider.models import Consumer, Scope
+ from oauth_provider.models import Token as OAuthToken
+ from oauth_provider import consts
+
+ self.consts = consts
+
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CONSUMER_KEY = 'consumer_key'
+ self.CONSUMER_SECRET = 'consumer_secret'
+ self.TOKEN_KEY = "token_key"
+ self.TOKEN_SECRET = "token_secret"
+
+ self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='example', user=self.user, status=self.consts.ACCEPTED)
+
+ self.scope = Scope.objects.create(name="resource name", url="api/")
+ self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope,
+ token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
+ )
+
+ def _create_authorization_header(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+
+ return req.to_header()["Authorization"]
+
+ def _create_authorization_url_parameters(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+ return dict(req)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_passing_oauth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_repeated_nonce_failing_oauth(self):
+ """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ # simulate reply attack auth header containes already used (nonce, timestamp) pair
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_token_removed_failing_oauth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_consumer_status_not_accepted_failing_oauth(self):
+ """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
+ for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
+ self.consumer.status = consumer_status
+ self.consumer.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_request_token_failing_oauth(self):
+ """Ensure POSTing with unauthorized request token instead of access token fails"""
+ self.token.token_type = self.token.REQUEST
+ self.token.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_urlencoded_parameters(self):
+ """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
+ params = self._create_authorization_url_parameters()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_url_parameters(self):
+ """Ensure GETing with auth in url parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_hmac_sha1_signature_passes(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_readonly_resource_passing_auth(self):
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.scope.is_readonly = True
+ read_only_access_token.scope.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_readonly_resource_failing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.scope.is_readonly = True
+ read_only_access_token.scope.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_write_resource_passing_auth(self):
+ """Ensure POSTing with a write resource succeed"""
+ read_write_access_token = self.token
+ read_write_access_token.scope.is_readonly = False
+ read_write_access_token.scope.save()
+ params = self._create_authorization_url_parameters()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_consumer_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': 'badconsumerkey'
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_token_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': 'badtokenkey',
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+
+class OAuth2Tests(TestCase):
+ """OAuth 2.0 authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CLIENT_ID = 'client_key'
+ self.CLIENT_SECRET = 'client_secret'
+ self.ACCESS_TOKEN = "access_token"
+ self.REFRESH_TOKEN = "refresh_token"
+
+ self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ redirect_uri='',
+ client_type=0,
+ name='example',
+ user=None,
+ )
+
+ self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ token=self.ACCESS_TOKEN,
+ client=self.oauth2_client,
+ user=self.user,
+ )
+ self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ user=self.user,
+ access_token=self.access_token,
+ client=self.oauth2_client
+ )
+
+ def _create_authorization_header(self, token=None):
+ return "Bearer {0}".format(token or self.access_token.token)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_type_failing(self):
+ """Ensure that a wrong token type lead to the correct HTTP error status code"""
+ auth = "Wrong token-type-obsviously"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_format_failing(self):
+ """Ensure that a wrong token format lead to the correct HTTP error status code"""
+ auth = "Bearer wrong token format"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_failing(self):
+ """Ensure that a wrong token lead to the correct HTTP error status code"""
+ auth = "Bearer wrong-token"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth(self):
+ """Ensure GETing form over OAuth with correct client credentials succeed"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_token_removed_failing_auth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.access_token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_refresh_token_failing_auth(self):
+ """Ensure POSTing with refresh token instead of access token fails"""
+ auth = self._create_authorization_header(token=self.refresh_token.token)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_expired_access_token_failing_auth(self):
+ """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
+ self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
+ self.access_token.save()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+ self.assertIn('Invalid token', response.content)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_invalid_scope_failing_auth(self):
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
+ read_only_access_token = self.access_token
+ read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
+ read_only_access_token.save()
+ auth = self._create_authorization_header(token=read_only_access_token.token)
+ response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_valid_scope_passing_auth(self):
+ """Ensure POSTing with a write scope succeed"""
+ read_write_access_token = self.access_token
+ read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
+ read_write_access_token.save()
+ auth = self._create_authorization_header(token=read_write_access_token.token)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+
+class FailingAuthAccessedInRenderer(TestCase):
+ def setUp(self):
+ class AuthAccessingRenderer(renderers.BaseRenderer):
+ media_type = 'text/plain'
+ format = 'txt'
+
+ def render(self, data, media_type=None, renderer_context=None):
+ request = renderer_context['request']
+ if request.user.is_authenticated():
+ return b'authenticated'
+ return b'not authenticated'
+
+ class FailingAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('authentication failed')
+
+ class ExampleView(APIView):
+ authentication_classes = (FailingAuth,)
+ renderer_classes = (AuthAccessingRenderer,)
+
+ def get(self, request):
+ return Response({'foo': 'bar'})
+
+ self.view = ExampleView.as_view()
+
+ def test_failing_auth_accessed_in_renderer(self):
+ """
+ When authentication fails the renderer should still be able to access
+ `request.user` without raising an exception. Particularly relevant
+ to HTML responses that might reasonably access `request.user`.
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ content = response.render().content
+ self.assertEqual(content, b'not authenticated')
diff --git a/tests/test_breadcrumbs.py b/tests/test_breadcrumbs.py
new file mode 100644
index 00000000..78edc603
--- /dev/null
+++ b/tests/test_breadcrumbs.py
@@ -0,0 +1,73 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.compat import patterns, url
+from rest_framework.utils.breadcrumbs import get_breadcrumbs
+from rest_framework.views import APIView
+
+
+class Root(APIView):
+ pass
+
+
+class ResourceRoot(APIView):
+ pass
+
+
+class ResourceInstance(APIView):
+ pass
+
+
+class NestedResourceRoot(APIView):
+ pass
+
+
+class NestedResourceInstance(APIView):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^$', Root.as_view()),
+ url(r'^resource/$', ResourceRoot.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()),
+)
+
+
+class BreadcrumbTests(TestCase):
+ """Tests the breadcrumb functionality used by the HTML renderer."""
+
+ urls = 'tests.test_breadcrumbs'
+
+ def test_root_breadcrumbs(self):
+ url = '/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
+
+ def test_resource_root_breadcrumbs(self):
+ url = '/resource/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/')])
+
+ def test_resource_instance_breadcrumbs(self):
+ url = '/resource/123'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123')])
+
+ def test_nested_resource_breadcrumbs(self):
+ url = '/resource/123/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123'),
+ ('Nested Resource Root', '/resource/123/')])
+
+ def test_nested_resource_instance_breadcrumbs(self):
+ url = '/resource/123/abc'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123'),
+ ('Nested Resource Root', '/resource/123/'),
+ ('Nested Resource Instance', '/resource/123/abc')])
+
+ def test_broken_url_breadcrumbs_handled_gracefully(self):
+ url = '/foobar'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
new file mode 100644
index 00000000..195f0ba3
--- /dev/null
+++ b/tests/test_decorators.py
@@ -0,0 +1,157 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.authentication import BasicAuthentication
+from rest_framework.parsers import JSONParser
+from rest_framework.permissions import IsAuthenticated
+from rest_framework.response import Response
+from rest_framework.renderers import JSONRenderer
+from rest_framework.test import APIRequestFactory
+from rest_framework.throttling import UserRateThrottle
+from rest_framework.views import APIView
+from rest_framework.decorators import (
+ api_view,
+ renderer_classes,
+ parser_classes,
+ authentication_classes,
+ throttle_classes,
+ permission_classes,
+)
+
+
+class DecoratorTestCase(TestCase):
+
+ def setUp(self):
+ self.factory = APIRequestFactory()
+
+ def _finalize_response(self, request, response, *args, **kwargs):
+ response.request = request
+ return APIView.finalize_response(self, request, response, *args, **kwargs)
+
+ def test_api_view_incorrect(self):
+ """
+ If @api_view is not applied correct, we should raise an assertion.
+ """
+
+ @api_view
+ def view(request):
+ return Response()
+
+ request = self.factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
+
+ with self.assertRaises(AssertionError):
+ @api_view('GET')
+ def view(request):
+ return Response()
+
+ def test_calling_method(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_calling_put_method(self):
+
+ @api_view(['GET', 'PUT'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.put('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_calling_patch_method(self):
+
+ @api_view(['GET', 'PATCH'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.patch('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_renderer_classes(self):
+
+ @api_view(['GET'])
+ @renderer_classes([JSONRenderer])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer))
+
+ def test_parser_classes(self):
+
+ @api_view(['GET'])
+ @parser_classes([JSONParser])
+ def view(request):
+ self.assertEqual(len(request.parsers), 1)
+ self.assertTrue(isinstance(request.parsers[0],
+ JSONParser))
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_authentication_classes(self):
+
+ @api_view(['GET'])
+ @authentication_classes([BasicAuthentication])
+ def view(request):
+ self.assertEqual(len(request.authenticators), 1)
+ self.assertTrue(isinstance(request.authenticators[0],
+ BasicAuthentication))
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_permission_classes(self):
+
+ @api_view(['GET'])
+ @permission_classes([IsAuthenticated])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_throttle_classes(self):
+ class OncePerDayUserThrottle(UserRateThrottle):
+ rate = '1/day'
+
+ @api_view(['GET'])
+ @throttle_classes([OncePerDayUserThrottle])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/tests/test_description.py b/tests/test_description.py
new file mode 100644
index 00000000..1e481f06
--- /dev/null
+++ b/tests/test_description.py
@@ -0,0 +1,108 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.compat import apply_markdown, smart_text
+from rest_framework.views import APIView
+from .description import ViewWithNonASCIICharactersInDocstring
+from .description import UTF8_TEST_DOCSTRING
+
+# We check that docstrings get nicely un-indented.
+DESCRIPTION = """an example docstring
+====================
+
+* list
+* list
+
+another header
+--------------
+
+ code block
+
+indented
+
+# hash style header #"""
+
+# If markdown is installed we also test it's working
+# (and that our wrapped forces '=' to h2 and '-' to h3)
+
+# We support markdown < 2.1 and markdown >= 2.1
+MARKED_DOWN_lt_21 = """<h2>an example docstring</h2>
+<ul>
+<li>list</li>
+<li>list</li>
+</ul>
+<h3>another header</h3>
+<pre><code>code block
+</code></pre>
+<p>indented</p>
+<h2 id="hash_style_header">hash style header</h2>"""
+
+MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
+<ul>
+<li>list</li>
+<li>list</li>
+</ul>
+<h3 id="another-header">another header</h3>
+<pre><code>code block
+</code></pre>
+<p>indented</p>
+<h2 id="hash-style-header">hash style header</h2>"""
+
+
+class TestViewNamesAndDescriptions(TestCase):
+ def test_view_name_uses_class_name(self):
+ """
+ Ensure view names are based on the class name.
+ """
+ class MockView(APIView):
+ pass
+ self.assertEqual(MockView().get_view_name(), 'Mock')
+
+ def test_view_description_uses_docstring(self):
+ """Ensure view descriptions are based on the docstring."""
+ class MockView(APIView):
+ """an example docstring
+ ====================
+
+ * list
+ * list
+
+ another header
+ --------------
+
+ code block
+
+ indented
+
+ # hash style header #"""
+
+ self.assertEqual(MockView().get_view_description(), DESCRIPTION)
+
+ def test_view_description_supports_unicode(self):
+ """
+ Unicode in docstrings should be respected.
+ """
+
+ self.assertEqual(
+ ViewWithNonASCIICharactersInDocstring().get_view_description(),
+ smart_text(UTF8_TEST_DOCSTRING)
+ )
+
+ def test_view_description_can_be_empty(self):
+ """
+ Ensure that if a view has no docstring,
+ then it's description is the empty string.
+ """
+ class MockView(APIView):
+ pass
+ self.assertEqual(MockView().get_view_description(), '')
+
+ def test_markdown(self):
+ """
+ Ensure markdown to HTML works as expected.
+ """
+ if apply_markdown:
+ gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
+ lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
+ self.assertTrue(gte_21_match or lt_21_match)
diff --git a/tests/test_fields.py b/tests/test_fields.py
new file mode 100644
index 00000000..e65a2fb3
--- /dev/null
+++ b/tests/test_fields.py
@@ -0,0 +1,984 @@
+"""
+General serializer field tests.
+"""
+from __future__ import unicode_literals
+
+import datetime
+from decimal import Decimal
+from uuid import uuid4
+from django.core import validators
+from django.db import models
+from django.test import TestCase
+from django.utils.datastructures import SortedDict
+from rest_framework import serializers
+from tests.models import RESTFrameworkModel
+
+
+class TimestampedModel(models.Model):
+ added = models.DateTimeField(auto_now_add=True)
+ updated = models.DateTimeField(auto_now=True)
+
+
+class CharPrimaryKeyModel(models.Model):
+ id = models.CharField(max_length=20, primary_key=True)
+
+
+class TimestampedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimestampedModel
+
+
+class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = CharPrimaryKeyModel
+
+
+class TimeFieldModel(models.Model):
+ clock = models.TimeField()
+
+
+class TimeFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimeFieldModel
+
+
+SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+]
+
+
+class ChoiceFieldModel(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255)
+
+
+class ChoiceFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModel
+
+
+class ChoiceFieldModelWithNull(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255)
+
+
+class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModelWithNull
+
+
+class BasicFieldTests(TestCase):
+ def test_auto_now_fields_read_only(self):
+ """
+ auto_now and auto_now_add fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEqual(serializer.fields['added'].read_only, True)
+
+ def test_auto_pk_fields_read_only(self):
+ """
+ AutoField fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEqual(serializer.fields['id'].read_only, True)
+
+ def test_non_auto_pk_fields_not_read_only(self):
+ """
+ PK fields other than AutoField fields should not be read_only by default.
+ """
+ serializer = CharPrimaryKeyModelSerializer()
+ self.assertEqual(serializer.fields['id'].read_only, False)
+
+ def test_dict_field_ordering(self):
+ """
+ Field should preserve dictionary ordering, if it exists.
+ See: https://github.com/tomchristie/django-rest-framework/issues/832
+ """
+ ret = SortedDict()
+ ret['c'] = 1
+ ret['b'] = 1
+ ret['a'] = 1
+ ret['z'] = 1
+ field = serializers.Field()
+ keys = list(field.to_native(ret).keys())
+ self.assertEqual(keys, ['c', 'b', 'a', 'z'])
+
+
+class DateFieldTest(TestCase):
+ """
+ Tests for the DateFieldTest from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native('1984-07-31')
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_from_native_datetime_date(self):
+ """
+ Make sure from_native() accepts a datetime.date instance.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(result_1, datetime.date(1984, 7, 31))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+ result = f.from_native('1984 -- 31')
+
+ self.assertEqual(datetime.date(1984, 1, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+
+ try:
+ f.from_native('1984-07-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_date(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid date.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984-13-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns datetime as default.
+ """
+ f = serializers.DateField()
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with 'iso-8601' returns iso formated date.
+ """
+ f = serializers.DateField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984-07-31', result_1)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateField(format="%Y - %m.%d")
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984 - 07.31', result_1)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class DateTimeFieldTest(TestCase):
+ """
+ Tests for the DateTimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native('1984-07-31 04:31')
+ result_2 = f.from_native('1984-07-31 04:31:59')
+ result_3 = f.from_native('1984-07-31 04:31:59.000200')
+
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_datetime(self):
+ """
+ Make sure from_native() accepts a datetime.datetime instance.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
+ self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
+ self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+ result = f.from_native('1984 -- 04:59')
+
+ self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+
+ try:
+ f.from_native('1984-07-31 04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_datetime(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid datetime.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateTimeField()
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format=iso-8601 returns iso formatted datetime.
+ """
+ f = serializers.DateTimeField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984-07-31T00:00:00', result_1)
+ self.assertEqual('1984-07-31T04:31:00', result_2)
+ self.assertEqual('1984-07-31T04:31:59', result_3)
+ self.assertEqual('1984-07-31T04:31:59.000200', result_4)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateTimeField(format="%Y - %H:%M")
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984 - 00:00', result_1)
+ self.assertEqual('1984 - 04:31', result_2)
+ self.assertEqual('1984 - 04:31', result_3)
+ self.assertEqual('1984 - 04:31', result_4)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class TimeFieldTest(TestCase):
+ """
+ Tests for the TimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.TimeField()
+ result_1 = f.from_native('04:31')
+ result_2 = f.from_native('04:31:59')
+ result_3 = f.from_native('04:31:59.000200')
+
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_time(self):
+ """
+ Make sure from_native() accepts a datetime.time instance.
+ """
+ f = serializers.TimeField()
+ result_1 = f.from_native(datetime.time(4, 31))
+ result_2 = f.from_native(datetime.time(4, 31, 59))
+ result_3 = f.from_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.time(4, 31))
+ self.assertEqual(result_2, datetime.time(4, 31, 59))
+ self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+ result = f.from_native('04 -- 31')
+
+ self.assertEqual(datetime.time(4, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+
+ try:
+ f.from_native('04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_time(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid time.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns time object as default.
+ """
+ f = serializers.TimeField()
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format='iso-8601' returns iso formatted time.
+ """
+ f = serializers.TimeField(format='iso-8601')
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04:31:00', result_1)
+ self.assertEqual('04:31:59', result_2)
+ self.assertEqual('04:31:59.000200', result_3)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.TimeField(format="%H - %S [%f]")
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04 - 00 [000000]', result_1)
+ self.assertEqual('04 - 59 [000000]', result_2)
+ self.assertEqual('04 - 59 [000200]', result_3)
+
+
+class DecimalFieldTest(TestCase):
+ """
+ Tests for the DecimalField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts string values
+ """
+ f = serializers.DecimalField()
+ result_1 = f.from_native('9000')
+ result_2 = f.from_native('1.00000001')
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_from_native_invalid_string(self):
+ """
+ Make sure from_native() raises ValidationError on passing invalid string
+ """
+ f = serializers.DecimalField()
+
+ try:
+ f.from_native('123.45.6')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Enter a number."])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_integer(self):
+ """
+ Make sure from_native() accepts integer values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(9000)
+
+ self.assertEqual(Decimal('9000'), result)
+
+ def test_from_native_float(self):
+ """
+ Make sure from_native() accepts float values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(1.00000001)
+
+ self.assertEqual(Decimal('1.00000001'), result)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns Decimal as string.
+ """
+ f = serializers.DecimalField()
+
+ result_1 = f.to_native(Decimal('9000'))
+ result_2 = f.to_native(Decimal('1.00000001'))
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+ def test_valid_serialization(self):
+ """
+ Make sure the serializer works correctly
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(max_value=9010,
+ min_value=9000,
+ max_digits=6,
+ decimal_places=2)
+
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
+
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
+
+ def test_raise_max_value(self):
+ """
+ Make sure max_value violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(max_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '123'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
+
+ def test_raise_min_value(self):
+ """
+ Make sure min_value violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(min_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '99'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
+
+ def test_raise_max_digits(self):
+ """
+ Make sure max_digits violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(max_digits=5)
+
+ s = DecimalSerializer(data={'decimal_field': '123.456'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
+
+ def test_raise_max_decimal_places(self):
+ """
+ Make sure max_decimal_places violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '123.4567'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
+
+ def test_raise_max_whole_digits(self):
+ """
+ Make sure max_whole_digits violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '12345.6'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
+
+
+class ChoiceFieldTests(TestCase):
+ """
+ Tests for the ChoiceField options generator
+ """
+ def test_choices_required(self):
+ """
+ Make sure proper choices are rendered if field is required
+ """
+ f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, SAMPLE_CHOICES)
+
+ def test_choices_not_required(self):
+ """
+ Make sure proper choices (plus blank) are rendered if the field isn't required
+ """
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
+
+ def test_invalid_choice_model(self):
+ s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
+ self.assertEqual(s.data['choice'], '')
+
+ def test_empty_choice_model(self):
+ """
+ Test that the 'empty' value is correctly passed and used depending on
+ the 'null' property on the model field.
+ """
+ s = ChoiceFieldModelSerializer(data={'choice': ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], '')
+
+ s = ChoiceFieldModelWithNullSerializer(data={'choice': ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], None)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns an empty string on empty param by default.
+ """
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.from_native(''), '')
+ self.assertEqual(f.from_native(None), '')
+
+ def test_from_native_empty_override(self):
+ """
+ Make sure you can override from_native() behavior regarding empty values.
+ """
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None)
+ self.assertEqual(f.from_native(''), None)
+ self.assertEqual(f.from_native(None), None)
+
+ def test_metadata_choices(self):
+ """
+ Make sure proper choices are included in the field's metadata.
+ """
+ choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES]
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.metadata()['choices'], choices)
+
+ def test_metadata_choices_not_required(self):
+ """
+ Make sure proper choices are included in the field's metadata.
+ """
+ choices = [{'value': v, 'display_name': n}
+ for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES]
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.metadata()['choices'], choices)
+
+
+class EmailFieldTests(TestCase):
+ """
+ Tests for EmailField attribute values
+ """
+
+ class EmailFieldModel(RESTFrameworkModel):
+ email_field = models.EmailField(blank=True)
+
+ class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ email_field = models.EmailField(max_length=150, blank=True)
+
+ def test_default_model_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.EmailFieldModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75)
+
+ def test_given_model_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.EmailFieldWithGivenMaxLengthModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150)
+
+ def test_given_serializer_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ email_field = serializers.EmailField(source='email_field', max_length=20, required=False)
+
+ class Meta:
+ model = self.EmailFieldModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20)
+
+
+class SlugFieldTests(TestCase):
+ """
+ Tests for SlugField attribute values
+ """
+
+ class SlugFieldModel(RESTFrameworkModel):
+ slug_field = models.SlugField(blank=True)
+
+ class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ slug_field = models.SlugField(max_length=84, blank=True)
+
+ def test_default_model_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.SlugFieldModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50)
+
+ def test_given_model_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.SlugFieldWithGivenMaxLengthModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84)
+
+ def test_given_serializer_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ slug_field = serializers.SlugField(source='slug_field',
+ max_length=20, required=False)
+
+ class Meta:
+ model = self.SlugFieldModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'],
+ 'max_length'), 20)
+
+ def test_invalid_slug(self):
+ """
+ Make sure an invalid slug raises ValidationError
+ """
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True)
+
+ class Meta:
+ model = self.SlugFieldModel
+
+ s = SlugFieldSerializer(data={'slug_field': 'a b'})
+
+ self.assertEqual(s.is_valid(), False)
+ self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
+
+
+class URLFieldTests(TestCase):
+ """
+ Tests for URLField attribute values.
+
+ (Includes test for #1210, checking that validators can be overridden.)
+ """
+
+ class URLFieldModel(RESTFrameworkModel):
+ url_field = models.URLField(blank=True)
+
+ class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ url_field = models.URLField(max_length=128, blank=True)
+
+ def test_default_model_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.URLFieldModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 200)
+
+ def test_given_model_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.URLFieldWithGivenMaxLengthModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 128)
+
+ def test_given_serializer_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ url_field = serializers.URLField(source='url_field',
+ max_length=20, required=False)
+
+ class Meta:
+ model = self.URLFieldWithGivenMaxLengthModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 20)
+
+ def test_validators_can_be_overridden(self):
+ url_field = serializers.URLField(validators=[])
+ validators = url_field.validators
+ self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators')
+
+
+class FieldMetadata(TestCase):
+ def setUp(self):
+ self.required_field = serializers.Field()
+ self.required_field.label = uuid4().hex
+ self.required_field.required = True
+
+ self.optional_field = serializers.Field()
+ self.optional_field.label = uuid4().hex
+ self.optional_field.required = False
+
+ def test_required(self):
+ self.assertEqual(self.required_field.metadata()['required'], True)
+
+ def test_optional(self):
+ self.assertEqual(self.optional_field.metadata()['required'], False)
+
+ def test_label(self):
+ for field in (self.required_field, self.optional_field):
+ self.assertEqual(field.metadata()['label'], field.label)
+
+
+class FieldCallableDefault(TestCase):
+ def setUp(self):
+ self.simple_callable = lambda: 'foo bar'
+
+ def test_default_can_be_simple_callable(self):
+ """
+ Ensure that the 'default' argument can also be a simple callable.
+ """
+ field = serializers.WritableField(default=self.simple_callable)
+ into = {}
+ field.field_from_native({}, {}, 'field', into)
+ self.assertEqual(into, {'field': 'foo bar'})
+
+
+class CustomIntegerField(TestCase):
+ """
+ Test that custom fields apply min_value and max_value constraints
+ """
+ def test_custom_fields_can_be_validated_for_value(self):
+
+ class MoneyField(models.PositiveIntegerField):
+ pass
+
+ class EntryModel(models.Model):
+ bank = MoneyField(validators=[validators.MaxValueValidator(100)])
+
+ class EntrySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = EntryModel
+
+ entry = EntryModel(bank=1)
+
+ serializer = EntrySerializer(entry, data={"bank": 11})
+ self.assertTrue(serializer.is_valid())
+
+ serializer = EntrySerializer(entry, data={"bank": -1})
+ self.assertFalse(serializer.is_valid())
+
+ serializer = EntrySerializer(entry, data={"bank": 101})
+ self.assertFalse(serializer.is_valid())
+
+
+class BooleanField(TestCase):
+ """
+ Tests for BooleanField
+ """
+ def test_boolean_required(self):
+ class BooleanRequiredSerializer(serializers.Serializer):
+ bool_field = serializers.BooleanField(required=True)
+
+ self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
diff --git a/tests/test_files.py b/tests/test_files.py
new file mode 100644
index 00000000..78f4cf42
--- /dev/null
+++ b/tests/test_files.py
@@ -0,0 +1,95 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.compat import BytesIO
+from rest_framework.compat import six
+import datetime
+
+
+class UploadedFile(object):
+ def __init__(self, file=None, created=None):
+ self.file = file
+ self.created = created or datetime.datetime.now()
+
+
+class UploadedFileSerializer(serializers.Serializer):
+ file = serializers.FileField(required=False)
+ created = serializers.DateTimeField()
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.file = attrs['file']
+ instance.created = attrs['created']
+ return instance
+ return UploadedFile(**attrs)
+
+
+class FileSerializerTests(TestCase):
+ def test_create(self):
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+ serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
+ uploaded_file = UploadedFile(file=file, created=now)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertEqual(serializer.object.file, uploaded_file.file)
+ self.assertFalse(serializer.object is uploaded_file)
+
+ def test_creation_failure(self):
+ """
+ Passing files=None should result in an ValidationError
+
+ Regression test for:
+ https://github.com/tomchristie/django-rest-framework/issues/542
+ """
+ now = datetime.datetime.now()
+
+ serializer = UploadedFileSerializer(data={'created': now})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, now)
+ self.assertIsNone(serializer.object.file)
+
+ def test_remove_with_empty_string(self):
+ """
+ Passing empty string as data should cause file to be removed
+
+ Test for:
+ https://github.com/tomchristie/django-rest-framework/issues/937
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertIsNone(serializer.object.file)
+
+ def test_validation_error_with_non_file(self):
+ """
+ Passing non-files should raise a validation error.
+ """
+ now = datetime.datetime.now()
+ errmsg = 'No file was submitted. Check the encoding type on the form.'
+
+ serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'file': [errmsg]})
+
+ def test_validation_with_no_data(self):
+ """
+ Validation should still function when no data dictionary is provided.
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(files={'file': file})
+ self.assertFalse(serializer.is_valid())
diff --git a/tests/test_filters.py b/tests/test_filters.py
new file mode 100644
index 00000000..d9d8042e
--- /dev/null
+++ b/tests/test_filters.py
@@ -0,0 +1,615 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.core.urlresolvers import reverse
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import generics, serializers, status, filters
+from rest_framework.compat import django_filters, patterns, url
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel
+
+factory = APIRequestFactory()
+
+
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # Regression test for #814
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+ class FilterFieldsQuerysetView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
+
+ urlpatterns = patterns('',
+ url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
+ )
+
+
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ self._serialize_object(obj)
+ for obj in self.objects.all()
+ ]
+
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_filter_with_queryset(self):
+ """
+ Regression test for #814.
+ """
+ view = FilterFieldsQuerysetView.as_view()
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'tests.test_filters'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterRelatedModel(models.Model):
+ related_object = models.ForeignKey(OrdringFilterModel,
+ related_name="relateds")
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ oredering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_ordering_by_aggregate_field(self):
+ # create some related models to aggregate order by
+ num_objs = [2, 5, 3]
+ for obj, num_relateds in zip(OrdringFilterModel.objects.all(),
+ num_objs):
+ for _ in range(num_relateds):
+ new_related = OrderingFilterRelatedModel(
+ related_object=obj
+ )
+ new_related.save()
+
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = '__all__'
+ queryset = OrdringFilterModel.objects.all().annotate(
+ models.Count("relateds"))
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=relateds__count')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ ]
+ )
+
+
+class SensitiveOrderingFilterModel(models.Model):
+ username = models.CharField(max_length=20)
+ password = models.CharField(max_length=100)
+
+
+# Three different styles of serializer.
+# All should allow ordering by username, but not by password.
+class SensitiveDataSerializer1(serializers.ModelSerializer):
+ username = serializers.CharField()
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'username')
+
+
+class SensitiveDataSerializer2(serializers.ModelSerializer):
+ username = serializers.CharField()
+ password = serializers.CharField(write_only=True)
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'username', 'password')
+
+
+class SensitiveDataSerializer3(serializers.ModelSerializer):
+ user = serializers.CharField(source='username')
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'user')
+
+
+class SensitiveOrderingFilterTests(TestCase):
+ def setUp(self):
+ for idx in range(3):
+ username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
+ password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
+ SensitiveOrderingFilterModel(username=username, password=password).save()
+
+ def test_order_by_serializer_fields(self):
+ for serializer_cls in [
+ SensitiveDataSerializer1,
+ SensitiveDataSerializer2,
+ SensitiveDataSerializer3
+ ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-username')
+ response = view(request)
+
+ if serializer_cls == SensitiveDataSerializer3:
+ username_field = 'user'
+ else:
+ username_field = 'username'
+
+ # Note: Inverse username ordering correctly applied.
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, username_field: 'userC'},
+ {'id': 2, username_field: 'userB'},
+ {'id': 1, username_field: 'userA'},
+ ]
+ )
+
+ def test_cannot_order_by_non_serializer_fields(self):
+ for serializer_cls in [
+ SensitiveDataSerializer1,
+ SensitiveDataSerializer2,
+ SensitiveDataSerializer3
+ ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=password')
+ response = view(request)
+
+ if serializer_cls == SensitiveDataSerializer3:
+ username_field = 'user'
+ else:
+ username_field = 'username'
+
+ # Note: The passwords are not in order. Default ordering is used.
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, username_field: 'userA'}, # PassB
+ {'id': 2, username_field: 'userB'}, # PassC
+ {'id': 3, username_field: 'userC'}, # PassA
+ ]
+ )
diff --git a/tests/test_genericrelations.py b/tests/test_genericrelations.py
new file mode 100644
index 00000000..2d341344
--- /dev/null
+++ b/tests/test_genericrelations.py
@@ -0,0 +1,129 @@
+from __future__ import unicode_literals
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class Tag(models.Model):
+ """
+ Tags have a descriptive slug, and are attached to an arbitrary object.
+ """
+ tag = models.SlugField()
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ tagged_item = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag
+
+
+class Bookmark(models.Model):
+ """
+ A URL bookmark that may have multiple tags attached.
+ """
+ url = models.URLField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Bookmark: %s' % self.url
+
+
+class Note(models.Model):
+ """
+ A textual note that may have multiple tags attached.
+ """
+ text = models.TextField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Note: %s' % self.text
+
+
+class TestGenericRelations(TestCase):
+ def setUp(self):
+ self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
+ Tag.objects.create(tagged_item=self.bookmark, tag='django')
+ Tag.objects.create(tagged_item=self.bookmark, tag='python')
+ self.note = Note.objects.create(text='Remember the milk')
+ Tag.objects.create(tagged_item=self.note, tag='reminder')
+
+ def test_generic_relation(self):
+ """
+ Test a relationship that spans a GenericRelation field.
+ IE. A reverse generic relationship.
+ """
+
+ class BookmarkSerializer(serializers.ModelSerializer):
+ tags = serializers.RelatedField(many=True)
+
+ class Meta:
+ model = Bookmark
+ exclude = ('id',)
+
+ serializer = BookmarkSerializer(self.bookmark)
+ expected = {
+ 'tags': ['django', 'python'],
+ 'url': 'https://www.djangoproject.com/'
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_generic_nested_relation(self):
+ """
+ Test saving a GenericRelation field via a nested serializer.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Tag
+ exclude = ('content_type', 'object_id')
+
+ class BookmarkSerializer(serializers.ModelSerializer):
+ tags = TagSerializer()
+
+ class Meta:
+ model = Bookmark
+ exclude = ('id',)
+
+ data = {
+ 'url': 'https://docs.djangoproject.com/',
+ 'tags': [
+ {'tag': 'contenttypes'},
+ {'tag': 'genericrelations'},
+ ]
+ }
+ serializer = BookmarkSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.object.tags.count(), 2)
+
+ def test_generic_fk(self):
+ """
+ Test a relationship that spans a GenericForeignKey field.
+ IE. A forward generic relationship.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ tagged_item = serializers.RelatedField()
+
+ class Meta:
+ model = Tag
+ exclude = ('id', 'content_type', 'object_id')
+
+ serializer = TagSerializer(Tag.objects.all(), many=True)
+ expected = [
+ {
+ 'tag': 'django',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'python',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'reminder',
+ 'tagged_item': 'Note: Remember the milk'
+ }
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_generics.py b/tests/test_generics.py
new file mode 100644
index 00000000..4389994a
--- /dev/null
+++ b/tests/test_generics.py
@@ -0,0 +1,609 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.shortcuts import get_object_or_404
+from django.test import TestCase
+from rest_framework import generics, renderers, serializers, status
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel, Comment, SlugBasedModel
+from rest_framework.compat import six
+
+factory = APIRequestFactory()
+
+
+class RootView(generics.ListCreateAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+
+ def get_queryset(self):
+ queryset = super(InstanceView, self).get_queryset()
+ return queryset.exclude(text='filtered out')
+
+
+class SlugSerializer(serializers.ModelSerializer):
+ slug = serializers.Field() # read only
+
+ class Meta:
+ model = SlugBasedModel
+ exclude = ('id',)
+
+
+class SlugBasedInstanceView(InstanceView):
+ """
+ A model with a slug-field.
+ """
+ model = SlugBasedModel
+ serializer_class = SlugSerializer
+ lookup_field = 'slug'
+
+
+class TestRootView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_root_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_post_root_view(self):
+ """
+ POST requests to ListCreateAPIView should create a new object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEqual(created.text, 'foobar')
+
+ def test_put_root_view(self):
+ """
+ PUT requests to ListCreateAPIView should not be allowed
+ """
+ data = {'text': 'foobar'}
+ request = factory.put('/', data, format='json')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
+
+ def test_delete_root_view(self):
+ """
+ DELETE requests to ListCreateAPIView should not be allowed
+ """
+ request = factory.delete('/')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
+
+ def test_options_root_view(self):
+ """
+ OPTIONS requests to ListCreateAPIView should return metadata
+ """
+ request = factory.options('/')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Root',
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'POST': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ "label": "Text comes here",
+ "help_text": "Text description."
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
+ }
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
+
+ def test_post_cannot_set_id(self):
+ """
+ POST requests to create a new object should not be able to set the id.
+ """
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEqual(created.text, 'foobar')
+
+
+class TestInstanceView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz', 'filtered out']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects.exclude(text='filtered out')
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = InstanceView.as_view()
+ self.slug_based_view = SlugBasedInstanceView.as_view()
+
+ def test_get_instance_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+ def test_post_instance_view(self):
+ """
+ POST requests to RetrieveUpdateDestroyAPIView should not be allowed
+ """
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
+
+ def test_put_instance_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk='1').render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_patch_instance_view(self):
+ """
+ PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/1', data, format='json')
+
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_delete_instance_view(self):
+ """
+ DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
+ """
+ request = factory.delete('/1')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ self.assertEqual(response.content, six.b(''))
+ ids = [obj.id for obj in self.objects.all()]
+ self.assertEqual(ids, [2, 3])
+
+ def test_options_instance_view(self):
+ """
+ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
+ """
+ request = factory.options('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Instance',
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'PUT': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ 'label': 'Text comes here',
+ 'help_text': 'Text description.'
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
+ }
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
+
+ def test_options_before_instance_create(self):
+ """
+ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
+ before the instance has been created
+ """
+ request = factory.options('/999')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Instance',
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'PUT': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ 'label': 'Text comes here',
+ 'help_text': 'Text description.'
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
+ }
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
+
+ def test_get_instance_view_incorrect_arg(self):
+ """
+ GET requests with an incorrect pk type, should raise 404, not 500.
+ Regression test for #890.
+ """
+ request = factory.get('/a')
+ with self.assertNumQueries(0):
+ response = self.view(request, pk='a').render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ def test_put_cannot_set_id(self):
+ """
+ PUT requests to create a new object should not be able to set the id.
+ """
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_put_to_deleted_instance(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ if it does not currently exist.
+ """
+ self.objects.get(id=1).delete()
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_put_to_filtered_out_instance(self):
+ """
+ PUT requests to an URL of instance which is filtered out should not be
+ able to create new objects.
+ """
+ data = {'text': 'foo'}
+ filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
+ request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
+ response = self.view(request, pk=filtered_out_pk).render()
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+ def test_put_as_create_on_id_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if it doesn't exist.
+ """
+ data = {'text': 'foobar'}
+ # pk fields can not be created on demand, only the database can set the pk for a new object
+ request = factory.put('/5', data, format='json')
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=5).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ new_obj = self.objects.get(pk=5)
+ self.assertEqual(new_obj.text, 'foobar')
+
+ def test_put_as_create_on_slug_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
+ """
+ data = {'text': 'foobar'}
+ request = factory.put('/test_slug', data, format='json')
+ with self.assertNumQueries(2):
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ new_obj = SlugBasedModel.objects.get(slug='test_slug')
+ self.assertEqual(new_obj.text, 'foobar')
+
+ def test_patch_cannot_create_an_object(self):
+ """
+ PATCH requests should not be able to create objects.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/999', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertFalse(self.objects.filter(id=999).exists())
+
+
+class TestOverriddenGetObject(TestCase):
+ """
+ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
+ queryset/model mechanism but instead overrides get_object()
+ """
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example detail view for override of get_object().
+ """
+ model = BasicModel
+
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
+
+ self.view = OverriddenGetObjectView.as_view()
+
+ def test_overridden_get_object_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
+# Regression test for #285
+
+class CommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Comment
+ exclude = ('created',)
+
+
+class CommentView(generics.ListCreateAPIView):
+ serializer_class = CommentSerializer
+ model = Comment
+
+
+class TestCreateModelWithAutoNowAddField(TestCase):
+ def setUp(self):
+ self.objects = Comment.objects
+ self.view = CommentView.as_view()
+
+ def test_create_model_with_auto_now_add_field(self):
+ """
+ Regression test for #285
+
+ https://github.com/tomchristie/django-rest-framework/issues/285
+ """
+ data = {'email': 'foobar@example.com', 'content': 'foobar'}
+ request = factory.post('/', data, format='json')
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ created = self.objects.get(id=1)
+ self.assertEqual(created.content, 'foobar')
+
+
+# Test for particularly ugly regression with m2m in browsable API
+class ClassB(models.Model):
+ name = models.CharField(max_length=255)
+
+
+class ClassA(models.Model):
+ name = models.CharField(max_length=255)
+ childs = models.ManyToManyField(ClassB, blank=True, null=True)
+
+
+class ClassASerializer(serializers.ModelSerializer):
+ childs = serializers.PrimaryKeyRelatedField(many=True, source='childs')
+
+ class Meta:
+ model = ClassA
+
+
+class ExampleView(generics.ListCreateAPIView):
+ serializer_class = ClassASerializer
+ model = ClassA
+
+
+class TestM2MBrowseableAPI(TestCase):
+ def test_m2m_in_browseable_api(self):
+ """
+ Test for particularly ugly regression with m2m in browsable API
+ """
+ request = factory.get('/', HTTP_ACCEPT='text/html')
+ view = ExampleView().as_view()
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class InclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='foo')
+
+
+class ExclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='other')
+
+
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ model = TwoFieldModel
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ return DynamicSerializer
+ return super(DynamicSerializerView, self).get_serializer_class()
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances to filter on.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ def test_get_root_view_filters_by_name_with_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
+
+ def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return empty list when all models are filtered out.
+ """
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [])
+
+ def test_get_instance_view_filters_out_name_with_filter_backend(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
+ """
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.data, {'detail': 'Not found'})
+
+ def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
+ """
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
+
+ def test_dynamic_serializer_form_in_browsable_api(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ view = DynamicSerializerView.as_view()
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertContains(response, 'field_b')
+ self.assertNotContains(response, 'field_a')
diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py
new file mode 100644
index 00000000..c748fbdb
--- /dev/null
+++ b/tests/test_htmlrenderer.py
@@ -0,0 +1,118 @@
+from __future__ import unicode_literals
+from django.core.exceptions import PermissionDenied
+from django.http import Http404
+from django.test import TestCase
+from django.template import TemplateDoesNotExist, Template
+import django.template.loader
+from rest_framework import status
+from rest_framework.compat import patterns, url
+from rest_framework.decorators import api_view, renderer_classes
+from rest_framework.renderers import TemplateHTMLRenderer
+from rest_framework.response import Response
+from rest_framework.compat import six
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def example(request):
+ """
+ A view that can returns an HTML representation.
+ """
+ data = {'object': 'foobar'}
+ return Response(data, template_name='example.html')
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def permission_denied(request):
+ raise PermissionDenied()
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def not_found(request):
+ raise Http404()
+
+
+urlpatterns = patterns('',
+ url(r'^$', example),
+ url(r'^permission_denied$', permission_denied),
+ url(r'^not_found$', not_found),
+)
+
+
+class TemplateHTMLRendererTests(TestCase):
+ urls = 'tests.test_htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == 'example.html':
+ return Template("example: {{ object }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_simple_html_view(self):
+ response = self.client.get('/')
+ self.assertContains(response, "example: foobar")
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_not_found_html_view(self):
+ response = self.client.get('/not_found')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404 Not Found"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_permission_denied_html_view(self):
+ response = self.client.get('/permission_denied')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403 Forbidden"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+
+class TemplateHTMLRendererExceptionTests(TestCase):
+ urls = 'tests.test_htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == '404.html':
+ return Template("404: {{ detail }}")
+ if template_name == '403.html':
+ return Template("403: {{ detail }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_not_found_html_view_with_template(self):
+ response = self.client.get('/not_found')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404: Not found"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_permission_denied_html_view_with_template(self):
+ response = self.client.get('/permission_denied')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403: Permission denied"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/tests/test_hyperlinkedserializers.py b/tests/test_hyperlinkedserializers.py
new file mode 100644
index 00000000..eee179ca
--- /dev/null
+++ b/tests/test_hyperlinkedserializers.py
@@ -0,0 +1,379 @@
+from __future__ import unicode_literals
+import json
+from django.test import TestCase
+from rest_framework import generics, status, serializers
+from rest_framework.compat import patterns, url
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from tests.models import (
+ Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
+ Album, Photo, OptionalRelationModel
+)
+
+factory = APIRequestFactory()
+
+
+class BlogPostCommentSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
+ text = serializers.CharField()
+ blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
+
+ class Meta:
+ model = BlogPostComment
+ fields = ('text', 'blog_post_url', 'url')
+
+
+class PhotoSerializer(serializers.Serializer):
+ description = serializers.CharField()
+ album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title')
+
+ def restore_object(self, attrs, instance=None):
+ return Photo(**attrs)
+
+
+class AlbumSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
+
+ class Meta:
+ model = Album
+ fields = ('title', 'url')
+
+
+class BasicList(generics.ListCreateAPIView):
+ model = BasicModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class AnchorDetail(generics.RetrieveAPIView):
+ model = Anchor
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class ManyToManyList(generics.ListAPIView):
+ model = ManyToManyModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class ManyToManyDetail(generics.RetrieveAPIView):
+ model = ManyToManyModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class BlogPostCommentListCreate(generics.ListCreateAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostCommentDetail(generics.RetrieveAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostDetail(generics.RetrieveAPIView):
+ model = BlogPost
+
+
+class PhotoListCreate(generics.ListCreateAPIView):
+ model = Photo
+ model_serializer_class = PhotoSerializer
+
+
+class AlbumDetail(generics.RetrieveAPIView):
+ model = Album
+ serializer_class = AlbumSerializer
+ lookup_field = 'title'
+
+
+class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = OptionalRelationModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+urlpatterns = patterns('',
+ url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
+ url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
+ url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
+ url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
+ url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
+ url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
+ url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
+ url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
+ url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
+ url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
+ url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
+)
+
+
+class TestBasicHyperlinkedView(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.list_view = BasicList.as_view()
+ self.detail_view = BasicDetail.as_view()
+
+ def test_get_list_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/basic/')
+ response = self.list_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_get_detail_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/basic/1')
+ response = self.detail_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
+class TestManyToManyHyperlinkedView(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ anchors = []
+ for item in items:
+ anchor = Anchor(text=item)
+ anchor.save()
+ anchors.append(anchor)
+
+ manytomany = ManyToManyModel()
+ manytomany.save()
+ manytomany.rel.add(*anchors)
+
+ self.data = [{
+ 'url': 'http://testserver/manytomany/1/',
+ 'rel': [
+ 'http://testserver/anchor/1/',
+ 'http://testserver/anchor/2/',
+ 'http://testserver/anchor/3/',
+ ]
+ }]
+ self.list_view = ManyToManyList.as_view()
+ self.detail_view = ManyToManyDetail.as_view()
+
+ def test_get_list_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/manytomany/')
+ response = self.list_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_get_detail_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/manytomany/1/')
+ response = self.detail_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
+class TestHyperlinkedIdentityFieldLookup(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 Album instances.
+ """
+ titles = ['foo', 'bar', 'baz']
+ for title in titles:
+ album = Album(title=title)
+ album.save()
+ self.detail_view = AlbumDetail.as_view()
+ self.data = {
+ 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
+ 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
+ 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
+ }
+
+ def test_lookup_field(self):
+ """
+ GET requests to AlbumDetail view should return serialized Albums
+ with a url field keyed by `title`.
+ """
+ for album in Album.objects.all():
+ request = factory.get('/albums/{0}/'.format(album.title))
+ response = self.detail_view(request, title=album.title)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[album.title])
+
+
+class TestCreateWithForeignKeys(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create a blog post
+ """
+ self.post = BlogPost.objects.create(title="Test post")
+ self.create_view = BlogPostCommentListCreate.as_view()
+
+ def test_create_comment(self):
+
+ data = {
+ 'text': 'A test comment',
+ 'blog_post_url': 'http://testserver/posts/1/'
+ }
+
+ request = factory.post('/comments/', data=data)
+ response = self.create_view(request)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response['Location'], 'http://testserver/comments/1/')
+ self.assertEqual(self.post.blogpostcomment_set.count(), 1)
+ self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
+
+
+class TestCreateWithForeignKeysAndCustomSlug(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create an Album
+ """
+ self.post = Album.objects.create(title='test-album')
+ self.list_create_view = PhotoListCreate.as_view()
+
+ def test_create_photo(self):
+
+ data = {
+ 'description': 'A test photo',
+ 'album_url': 'http://testserver/albums/test-album/'
+ }
+
+ request = factory.post('/photos/', data=data)
+ response = self.list_create_view(request)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
+ self.assertEqual(self.post.photo_set.count(), 1)
+ self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
+
+
+class TestOptionalRelationHyperlinkedView(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 1 OptionalRelationModel instances.
+ """
+ OptionalRelationModel().save()
+ self.objects = OptionalRelationModel.objects
+ self.detail_view = OptionalRelationDetail.as_view()
+ self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
+
+ def test_get_detail_view(self):
+ """
+ GET requests to RetrieveAPIView with optional relations should return None
+ for non existing relations.
+ """
+ request = factory.get('/optionalrelationmodel-detail/1')
+ response = self.detail_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_put_detail_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView with optional relations
+ should accept None for non existing relations.
+ """
+ response = self.client.put('/optionalrelation/1/',
+ data=json.dumps(self.data),
+ content_type='application/json')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class TestOverriddenURLField(TestCase):
+ def setUp(self):
+ class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer):
+ url = serializers.SerializerMethodField('get_url')
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', 'url')
+
+ def get_url(self, obj):
+ return 'foo bar'
+
+ self.Serializer = OverriddenURLSerializer
+ self.obj = BlogPost.objects.create(title='New blog post')
+
+ def test_overridden_url_field(self):
+ """
+ The 'url' field should respect overriding.
+ Regression test for #936.
+ """
+ serializer = self.Serializer(self.obj)
+ self.assertEqual(
+ serializer.data,
+ {'title': 'New blog post', 'url': 'foo bar'}
+ )
+
+
+class TestURLFieldNameBySettings(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ self.saved_url_field_name = api_settings.URL_FIELD_NAME
+ api_settings.URL_FIELD_NAME = 'global_url_field'
+
+ class Serializer(serializers.HyperlinkedModelSerializer):
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', api_settings.URL_FIELD_NAME)
+
+ self.Serializer = Serializer
+ self.obj = BlogPost.objects.create(title="New blog post")
+
+ def tearDown(self):
+ api_settings.URL_FIELD_NAME = self.saved_url_field_name
+
+ def test_overridden_url_field_name(self):
+ request = factory.get('/posts/')
+ serializer = self.Serializer(self.obj, context={'request': request})
+ self.assertIn(api_settings.URL_FIELD_NAME, serializer.data)
+
+
+class TestURLFieldNameByOptions(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ class Serializer(serializers.HyperlinkedModelSerializer):
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', 'serializer_url_field')
+ url_field_name = 'serializer_url_field'
+
+ self.Serializer = Serializer
+ self.obj = BlogPost.objects.create(title="New blog post")
+
+ def test_overridden_url_field_name(self):
+ request = factory.get('/posts/')
+ serializer = self.Serializer(self.obj, context={'request': request})
+ self.assertIn(self.Serializer.Meta.url_field_name, serializer.data)
diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py
new file mode 100644
index 00000000..ce1bf3ea
--- /dev/null
+++ b/tests/test_multitable_inheritance.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import RESTFrameworkModel
+
+
+# Models
+class ParentModel(RESTFrameworkModel):
+ name1 = models.CharField(max_length=100)
+
+
+class ChildModel(ParentModel):
+ name2 = models.CharField(max_length=100)
+
+
+class AssociatedModel(RESTFrameworkModel):
+ ref = models.OneToOneField(ParentModel, primary_key=True)
+ name = models.CharField(max_length=100)
+
+
+# Serializers
+class DerivedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChildModel
+
+
+class AssociatedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AssociatedModel
+
+
+# Tests
+class IneritedModelSerializationTests(TestCase):
+
+ def test_multitable_inherited_model_fields_as_expected(self):
+ """
+ Assert that the parent pointer field is not included in the fields
+ serialized fields
+ """
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name1', 'name2', 'id']))
+
+ def test_onetoone_primary_key_model_fields_as_expected(self):
+ """
+ Assert that a model with a onetoone field that is the primary key is
+ not treated like a derived model
+ """
+ parent = ParentModel(name1='parent name')
+ associate = AssociatedModel(name='hello', ref=parent)
+ serializer = AssociatedModelSerializer(associate)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'ref']))
+
+ def test_data_is_valid_without_parent_ptr(self):
+ """
+ Assert that the pointer to the parent table is not a required field
+ for input data
+ """
+ data = {
+ 'name1': 'parent name',
+ 'name2': 'child name',
+ }
+ serializer = DerivedModelSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py
new file mode 100644
index 00000000..04b89eb6
--- /dev/null
+++ b/tests/test_negotiation.py
@@ -0,0 +1,45 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.negotiation import DefaultContentNegotiation
+from rest_framework.request import Request
+from rest_framework.renderers import BaseRenderer
+from rest_framework.test import APIRequestFactory
+
+
+factory = APIRequestFactory()
+
+
+class MockJSONRenderer(BaseRenderer):
+ media_type = 'application/json'
+
+
+class MockHTMLRenderer(BaseRenderer):
+ media_type = 'text/html'
+
+
+class NoCharsetSpecifiedRenderer(BaseRenderer):
+ media_type = 'my/media'
+
+
+class TestAcceptedMediaType(TestCase):
+ def setUp(self):
+ self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
+ self.negotiator = DefaultContentNegotiation()
+
+ def select_renderer(self, request):
+ return self.negotiator.select_renderer(request, self.renderers)
+
+ def test_client_without_accept_use_renderer(self):
+ request = Request(factory.get('/'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json')
+
+ def test_client_underspecifies_accept_use_renderer(self):
+ request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json')
+
+ def test_client_overspecifies_accept_use_client(self):
+ request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json; indent=8')
diff --git a/tests/test_nullable_fields.py b/tests/test_nullable_fields.py
new file mode 100644
index 00000000..33a9685f
--- /dev/null
+++ b/tests/test_nullable_fields.py
@@ -0,0 +1,30 @@
+from django.core.urlresolvers import reverse
+
+from rest_framework.compat import patterns, url
+from rest_framework.test import APITestCase
+from tests.models import NullableForeignKeySource
+from tests.serializers import NullableFKSourceSerializer
+from tests.views import NullableFKSourceDetail
+
+
+urlpatterns = patterns(
+ '',
+ url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'),
+)
+
+
+class NullableForeignKeyTests(APITestCase):
+ """
+ DRF should be able to handle nullable foreign keys when a test
+ Client POST/PUT request is made with its own serialized object.
+ """
+ urls = 'tests.test_nullable_fields'
+
+ def test_updating_object_with_null_fk(self):
+ obj = NullableForeignKeySource(name='example', target=None)
+ obj.save()
+ serialized_data = NullableFKSourceSerializer(obj).data
+
+ response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data)
+
+ self.assertEqual(response.data, serialized_data)
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
new file mode 100644
index 00000000..65fa9dcd
--- /dev/null
+++ b/tests/test_pagination.py
@@ -0,0 +1,517 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.core.paginator import Paginator
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import generics, status, pagination, filters, serializers
+from rest_framework.compat import django_filters
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel
+
+factory = APIRequestFactory()
+
+
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+class RootView(generics.ListCreateAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+ paginate_by = 10
+
+
+class DefaultPageSizeKwargView(generics.ListAPIView):
+ """
+ View for testing default paginate_by_param usage
+ """
+ model = BasicModel
+
+
+class PaginateByParamView(generics.ListAPIView):
+ """
+ View for testing custom paginate_by_param usage
+ """
+ model = BasicModel
+ paginate_by_param = 'page_size'
+
+
+class MaxPaginateByView(generics.ListAPIView):
+ """
+ View for testing custom max_paginate_by usage
+ """
+ model = BasicModel
+ paginate_by = 3
+ max_paginate_by = 5
+ paginate_by_param = 'page_size'
+
+
+class IntegrationTestPagination(TestCase):
+ """
+ Integration tests for paginated list views.
+ """
+
+ def setUp(self):
+ """
+ Create 26 BasicModel instances.
+ """
+ for char in 'abcdefghijklmnopqrstuvwxyz':
+ BasicModel(text=char * 3).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_paginated_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ request = factory.get('/')
+ # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[10:20])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[20:])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+
+class IntegrationTestPaginationAndFiltering(TestCase):
+
+ def setUp(self):
+ """
+ Create 50 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(26):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_django_filter_paginated_filtered_root_view(self):
+ """
+ GET requests to paginated filtered ListCreateAPIView should return
+ paginated results. The next and previous links should preserve the
+ filtered parameters.
+ """
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ view = FilterFieldsRootView.as_view()
+
+ EXPECTED_NUM_QUERIES = 2
+
+ request = factory.get('/?decimal=15.20')
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ def test_get_basic_paginated_filtered_root_view(self):
+ """
+ Same as `test_get_django_filter_paginated_filtered_root_view`,
+ except using a custom filter backend instead of the django-filter
+ backend,
+ """
+
+ class DecimalFilterBackend(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
+
+ class BasicFilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_backends = (DecimalFilterBackend,)
+
+ view = BasicFilterFieldsRootView.as_view()
+
+ request = factory.get('/?decimal=15.20')
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+
+class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
+ class Meta:
+ object_serializer_class = serializers.Serializer
+
+
+class UnitTestPagination(TestCase):
+ """
+ Unit tests for pagination of primitive objects.
+ """
+
+ def setUp(self):
+ self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
+ paginator = Paginator(self.objects, 10)
+ self.first_page = paginator.page(1)
+ self.last_page = paginator.page(3)
+
+ def test_native_pagination(self):
+ serializer = pagination.PaginationSerializer(self.first_page)
+ self.assertEqual(serializer.data['count'], 26)
+ self.assertEqual(serializer.data['next'], '?page=2')
+ self.assertEqual(serializer.data['previous'], None)
+ self.assertEqual(serializer.data['results'], self.objects[:10])
+
+ serializer = pagination.PaginationSerializer(self.last_page)
+ self.assertEqual(serializer.data['count'], 26)
+ self.assertEqual(serializer.data['next'], None)
+ self.assertEqual(serializer.data['previous'], '?page=2')
+ self.assertEqual(serializer.data['results'], self.objects[20:])
+
+ def test_context_available_in_result(self):
+ """
+ Ensure context gets passed through to the object serializer.
+ """
+ serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
+ serializer.data
+ results = serializer.fields[serializer.results_field]
+ self.assertEqual(serializer.context, results.context)
+
+
+class TestUnpaginated(TestCase):
+ """
+ Tests for list views without pagination.
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = DefaultPageSizeKwargView.as_view()
+
+ def test_unpaginated(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ self.assertEqual(response.data, self.data)
+
+
+class TestCustomPaginateByParam(TestCase):
+ """
+ Tests for list views with default page size kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = PaginateByParamView.as_view()
+
+ def test_default_page_size(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEqual(response.data, self.data)
+
+ def test_paginate_by_param(self):
+ """
+ If paginate_by_param is set, the new kwarg should limit per view requests.
+ """
+ request = factory.get('/?page_size=5')
+ response = self.view(request).render()
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+
+class TestMaxPaginateByParam(TestCase):
+ """
+ Tests for list views with max_paginate_by kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = MaxPaginateByView.as_view()
+
+ def test_max_paginate_by(self):
+ """
+ If max_paginate_by is set, it should limit page size for the view.
+ """
+ request = factory.get('/?page_size=10')
+ response = self.view(request).render()
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+ def test_max_paginate_by_without_page_size_param(self):
+ """
+ If max_paginate_by is set, but client does not specifiy page_size,
+ standard `paginate_by` behavior should be used.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEqual(response.data['results'], self.data[:3])
+
+
+### Tests for context in pagination serializers
+
+class CustomField(serializers.Field):
+ def to_native(self, value):
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into custom field")
+ return "value"
+
+
+class BasicModelSerializer(serializers.Serializer):
+ text = CustomField()
+
+ def __init__(self, *args, **kwargs):
+ super(BasicModelSerializer, self).__init__(*args, **kwargs)
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into serializer init")
+
+
+class TestContextPassedToCustomField(TestCase):
+ def setUp(self):
+ BasicModel.objects.create(text='ala ma kota')
+
+ def test_with_pagination(self):
+ class ListView(generics.ListCreateAPIView):
+ model = BasicModel
+ serializer_class = BasicModelSerializer
+ paginate_by = 1
+
+ self.view = ListView.as_view()
+ request = factory.get('/')
+ response = self.view(request).render()
+
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+### Tests for custom pagination serializers
+
+class LinksSerializer(serializers.Serializer):
+ next = pagination.NextPageField(source='*')
+ prev = pagination.PreviousPageField(source='*')
+
+
+class CustomPaginationSerializer(pagination.BasePaginationSerializer):
+ links = LinksSerializer(source='*') # Takes the page object as the source
+ total_results = serializers.Field(source='paginator.count')
+
+ results_field = 'objects'
+
+
+class TestCustomPaginationSerializer(TestCase):
+ def setUp(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = Paginator(objects, 2)
+ self.page = paginator.page(1)
+
+ def test_custom_pagination_serializer(self):
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=self.page,
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page=2',
+ 'prev': None
+ },
+ 'total_results': 4,
+ 'objects': ['john', 'paul']
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+class NonIntegerPage(object):
+
+ def __init__(self, paginator, object_list, prev_token, token, next_token):
+ self.paginator = paginator
+ self.object_list = object_list
+ self.prev_token = prev_token
+ self.token = token
+ self.next_token = next_token
+
+ def has_next(self):
+ return not not self.next_token
+
+ def next_page_number(self):
+ return self.next_token
+
+ def has_previous(self):
+ return not not self.prev_token
+
+ def previous_page_number(self):
+ return self.prev_token
+
+
+class NonIntegerPaginator(object):
+
+ def __init__(self, object_list, per_page):
+ self.object_list = object_list
+ self.per_page = per_page
+
+ def count(self):
+ # pretend like we don't know how many pages we have
+ return None
+
+ def page(self, token=None):
+ if token:
+ try:
+ first = self.object_list.index(token)
+ except ValueError:
+ first = 0
+ else:
+ first = 0
+ n = len(self.object_list)
+ last = min(first + self.per_page, n)
+ prev_token = self.object_list[last - (2 * self.per_page)] if first else None
+ next_token = self.object_list[last] if last < n else None
+ return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
+
+
+class TestNonIntegerPagination(TestCase):
+
+
+ def test_custom_pagination_serializer(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = NonIntegerPaginator(objects, 2)
+
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=paginator.page(),
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
+ 'prev': None
+ },
+ 'total_results': None,
+ 'objects': objects[:2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=paginator.page('george'),
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': None,
+ 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
+ },
+ 'total_results': None,
+ 'objects': objects[2:]
+ }
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
new file mode 100644
index 00000000..7699e10c
--- /dev/null
+++ b/tests/test_parsers.py
@@ -0,0 +1,115 @@
+from __future__ import unicode_literals
+from rest_framework.compat import StringIO
+from django import forms
+from django.core.files.uploadhandler import MemoryFileUploadHandler
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework.compat import etree
+from rest_framework.parsers import FormParser, FileUploadParser
+from rest_framework.parsers import XMLParser
+import datetime
+
+
+class Form(forms.Form):
+ field1 = forms.CharField(max_length=3)
+ field2 = forms.CharField()
+
+
+class TestFormParser(TestCase):
+ def setUp(self):
+ self.string = "field1=abc&field2=defghijk"
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FormParser()
+
+ stream = StringIO(self.string)
+ data = parser.parse(stream)
+
+ self.assertEqual(Form(data).is_valid(), True)
+
+
+class TestXMLParser(TestCase):
+ def setUp(self):
+ self._input = StringIO(
+ '<?xml version="1.0" encoding="utf-8"?>'
+ '<root>'
+ '<field_a>121.0</field_a>'
+ '<field_b>dasd</field_b>'
+ '<field_c></field_c>'
+ '<field_d>2011-12-25 12:45:00</field_d>'
+ '</root>'
+ )
+ self._data = {
+ 'field_a': 121,
+ 'field_b': 'dasd',
+ 'field_c': None,
+ 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
+ }
+ self._complex_data_input = StringIO(
+ '<?xml version="1.0" encoding="utf-8"?>'
+ '<root>'
+ '<creation_date>2011-12-25 12:45:00</creation_date>'
+ '<sub_data_list>'
+ '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>'
+ '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>'
+ '</sub_data_list>'
+ '<name>name</name>'
+ '</root>'
+ )
+ self._complex_data = {
+ "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
+ "name": "name",
+ "sub_data_list": [
+ {
+ "sub_id": 1,
+ "sub_name": "first"
+ },
+ {
+ "sub_id": 2,
+ "sub_name": "second"
+ }
+ ]
+ }
+
+ @unittest.skipUnless(etree, 'defusedxml not installed')
+ def test_parse(self):
+ parser = XMLParser()
+ data = parser.parse(self._input)
+ self.assertEqual(data, self._data)
+
+ @unittest.skipUnless(etree, 'defusedxml not installed')
+ def test_complex_data_parse(self):
+ parser = XMLParser()
+ data = parser.parse(self._complex_data_input)
+ self.assertEqual(data, self._complex_data)
+
+
+class TestFileUploadParser(TestCase):
+ def setUp(self):
+ class MockRequest(object):
+ pass
+ from io import BytesIO
+ self.stream = BytesIO(
+ "Test text file".encode('utf-8')
+ )
+ request = MockRequest()
+ request.upload_handlers = (MemoryFileUploadHandler(),)
+ request.META = {
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
+ 'HTTP_CONTENT_LENGTH': 14,
+ }
+ self.parser_context = {'request': request, 'kwargs': {}}
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ data_and_files = parser.parse(self.stream, None, self.parser_context)
+ file_obj = data_and_files.files['file']
+ self.assertEqual(file_obj._size, 14)
+
+ def test_get_filename(self):
+ parser = FileUploadParser()
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'file.txt'.encode('utf-8'))
diff --git a/tests/test_permissions.py b/tests/test_permissions.py
new file mode 100644
index 00000000..a2cb0c36
--- /dev/null
+++ b/tests/test_permissions.py
@@ -0,0 +1,291 @@
+from __future__ import unicode_literals
+from django.contrib.auth.models import User, Permission, Group
+from django.db import models
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
+from rest_framework.compat import guardian, get_model_name
+from rest_framework.filters import DjangoObjectPermissionsFilter
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel
+import base64
+
+factory = APIRequestFactory()
+
+class RootView(generics.ListCreateAPIView):
+ model = BasicModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+root_view = RootView.as_view()
+instance_view = InstanceView.as_view()
+
+
+def basic_auth_header(username, password):
+ credentials = ('%s:%s' % (username, password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ return 'Basic %s' % base64_credentials
+
+
+class ModelPermissionsIntegrationTests(TestCase):
+ def setUp(self):
+ User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
+ user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='add_basicmodel'),
+ Permission.objects.get(codename='change_basicmodel'),
+ Permission.objects.get(codename='delete_basicmodel')
+ ]
+ user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='change_basicmodel'),
+ ]
+
+ self.permitted_credentials = basic_auth_header('permitted', 'password')
+ self.disallowed_credentials = basic_auth_header('disallowed', 'password')
+ self.updateonly_credentials = basic_auth_header('updateonly', 'password')
+
+ BasicModel(text='foo').save()
+
+ def test_has_create_permissions(self):
+ request = factory.post('/', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+ def test_has_put_permissions(self):
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_has_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_does_not_have_create_permissions(self):
+ request = factory.post('/', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_put_permissions(self):
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_has_put_as_create_permissions(self):
+ # User only has update permissions - should be able to update an entity.
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ # But if PUTing to a new entity, permission should be denied.
+ request = factory.put('/2', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='2')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_options_permitted(self):
+ request = factory.options('/',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['POST'])
+
+ request = factory.options('/1',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
+ def test_options_disallowed(self):
+ request = factory.options('/',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options('/1',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ def test_options_updateonly(self):
+ request = factory.options('/',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options('/1',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
+
+class BasicPermModel(models.Model):
+ text = models.CharField(max_length=100)
+
+ class Meta:
+ app_label = 'tests'
+ permissions = (
+ ('view_basicpermmodel', 'Can view basic perm model'),
+ # add, change, delete built in to django
+ )
+
+# Custom object-level permission, that includes 'view' permissions
+class ViewObjectPermissions(permissions.DjangoObjectPermissions):
+ perms_map = {
+ 'GET': ['%(app_label)s.view_%(model_name)s'],
+ 'OPTIONS': ['%(app_label)s.view_%(model_name)s'],
+ 'HEAD': ['%(app_label)s.view_%(model_name)s'],
+ 'POST': ['%(app_label)s.add_%(model_name)s'],
+ 'PUT': ['%(app_label)s.change_%(model_name)s'],
+ 'PATCH': ['%(app_label)s.change_%(model_name)s'],
+ 'DELETE': ['%(app_label)s.delete_%(model_name)s'],
+ }
+
+
+class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicPermModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [ViewObjectPermissions]
+
+object_permissions_view = ObjectPermissionInstanceView.as_view()
+
+
+class ObjectPermissionListView(generics.ListAPIView):
+ model = BasicPermModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [ViewObjectPermissions]
+
+object_permissions_list_view = ObjectPermissionListView.as_view()
+
+
+@unittest.skipUnless(guardian, 'django-guardian not installed')
+class ObjectPermissionsIntegrationTests(TestCase):
+ """
+ Integration tests for the object level permissions API.
+ """
+ def setUp(self):
+ from guardian.shortcuts import assign_perm
+
+ # create users
+ create = User.objects.create_user
+ users = {
+ 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
+ 'readonly': create('readonly', 'readonly@example.com', 'password'),
+ 'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
+ 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
+ }
+
+ # give everyone model level permissions, as we are not testing those
+ everyone = Group.objects.create(name='everyone')
+ model_name = get_model_name(BasicPermModel)
+ app_label = BasicPermModel._meta.app_label
+ f = '{0}_{1}'.format
+ perms = {
+ 'view': f('view', model_name),
+ 'change': f('change', model_name),
+ 'delete': f('delete', model_name)
+ }
+ for perm in perms.values():
+ perm = '{0}.{1}'.format(app_label, perm)
+ assign_perm(perm, everyone)
+ everyone.user_set.add(*users.values())
+
+ # appropriate object level permissions
+ readers = Group.objects.create(name='readers')
+ writers = Group.objects.create(name='writers')
+ deleters = Group.objects.create(name='deleters')
+
+ model = BasicPermModel.objects.create(text='foo')
+
+ assign_perm(perms['view'], readers, model)
+ assign_perm(perms['change'], writers, model)
+ assign_perm(perms['delete'], deleters, model)
+
+ readers.user_set.add(users['fullaccess'], users['readonly'])
+ writers.user_set.add(users['fullaccess'], users['writeonly'])
+ deleters.user_set.add(users['fullaccess'], users['deleteonly'])
+
+ self.credentials = {}
+ for user in users.values():
+ self.credentials[user.username] = basic_auth_header(user.username, 'password')
+
+ # Delete
+ def test_can_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_cannot_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ # Update
+ def test_can_update_permissions(self):
+ request = factory.patch('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data.get('text'), 'foobar')
+
+ def test_cannot_update_permissions(self):
+ request = factory.patch('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['deleteonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ def test_cannot_update_permissions_non_existing(self):
+ request = factory.patch('/999', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['deleteonly'])
+ response = object_permissions_view(request, pk='999')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Read
+ def test_can_read_permissions(self):
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_cannot_read_permissions(self):
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Read list
+ def test_can_read_list_permissions(self):
+ request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ response = object_permissions_list_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data[0].get('id'), 1)
+
+ def test_cannot_read_list_permissions(self):
+ request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ response = object_permissions_list_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertListEqual(response.data, [])
diff --git a/tests/test_relations.py b/tests/test_relations.py
new file mode 100644
index 00000000..bfc8d487
--- /dev/null
+++ b/tests/test_relations.py
@@ -0,0 +1,120 @@
+"""
+General tests for relational fields.
+"""
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import BlogPost
+
+
+class NullModel(models.Model):
+ pass
+
+
+class FieldTests(TestCase):
+ def test_pk_related_field_with_empty_string(self):
+ """
+ Regression test for #446
+
+ https://github.com/tomchristie/django-rest-framework/issues/446
+ """
+ field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_hyperlinked_related_field_with_empty_string(self):
+ field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_slug_related_field_with_empty_string(self):
+ field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+
+class TestManyRelatedMixin(TestCase):
+ def test_missing_many_to_many_related_field(self):
+ '''
+ Regression test for #632
+
+ https://github.com/tomchristie/django-rest-framework/pull/632
+ '''
+ field = serializers.RelatedField(many=True, read_only=False)
+
+ into = {}
+ field.field_from_native({}, None, 'field_name', into)
+ self.assertEqual(into['field_name'], [])
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class RelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ # Regression for #1129
+ def test_exception_for_incorect_fk(self):
+ """
+ Check that the exception message are correct if the source field
+ doesn't exist.
+ """
+ from tests.models import ManyToManySource
+ class Meta:
+ model = ManyToManySource
+ attrs = {
+ 'name': serializers.SlugRelatedField(
+ slug_field='name', source='banzai'),
+ 'Meta': Meta,
+ }
+
+ TestSerializer = type(str('TestSerializer'),
+ (serializers.ModelSerializer,), attrs)
+ with self.assertRaises(AttributeError):
+ TestSerializer(data={'name': 'foo'})
diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py
new file mode 100644
index 00000000..98f68d29
--- /dev/null
+++ b/tests/test_relations_hyperlink.py
@@ -0,0 +1,524 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.compat import patterns, url
+from rest_framework.test import APIRequestFactory
+from tests.models import (
+ BlogPost,
+ ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+)
+
+factory = APIRequestFactory()
+request = factory.get('/') # Just to ensure we have a request in the serializer context
+
+
+def dummy_view(request, pk):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
+ url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
+ url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
+ url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
+ url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
+ url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
+ url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
+ url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
+)
+
+
+# ManyToMany
+class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
+
+
+class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+
+
+# ForeignKey
+class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
+
+
+class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+ fields = ('url', 'name', 'target')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
+
+
+# Nullable OneToOne
+class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class HyperlinkedManyToManyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ serializer = ManyToManySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class HyperlinkedForeignKeyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+
+
+class HyperlinkedNullableForeignKeyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEqual(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'sources': [1]},
+ # {'id': 2, 'name': 'target-2', 'sources': []},
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class HyperlinkedNullableOneToOneTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
+ {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class HyperlinkedRelatedFieldSourceTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_manager',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_queryset',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='a.b.c',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
diff --git a/tests/test_relations_nested.py b/tests/test_relations_nested.py
new file mode 100644
index 00000000..d393b0c3
--- /dev/null
+++ b/tests/test_relations_nested.py
@@ -0,0 +1,328 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class OneToOneTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class OneToOneSource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, related_name='source',
+ null=True, blank=True)
+
+
+class OneToManyTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class OneToManySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(OneToManyTarget, related_name='sources')
+
+
+class ReverseNestedOneToOneTests(TestCase):
+ def setUp(self):
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name')
+
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ source = OneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'source')
+
+ self.Serializer = OneToOneTargetSerializer
+
+ for idx in range(1, 4):
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
+ {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
+
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ instance = OneToOneTarget.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class ForwardNestedOneToOneTests(TestCase):
+ def setUp(self):
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name')
+
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ target = OneToOneTargetSerializer()
+
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name', 'target')
+
+ self.Serializer = OneToOneSourceSerializer
+
+ for idx in range(1, 4):
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
+ {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
+
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_update_to_null(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': None}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+ self.assertEqual(obj.target, None)
+
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # TODO: Nullable 1-1 tests
+ # def test_one_to_one_delete(self):
+ # data = {'id': 3, 'name': 'target-3', 'target_source': None}
+ # instance = OneToOneTarget.objects.get(pk=3)
+ # serializer = self.Serializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # serializer.save()
+
+ # # Ensure (target_source 3, source 3) are deleted,
+ # # and everything else is as expected.
+ # queryset = OneToOneTarget.objects.all()
+ # serializer = self.Serializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ # {'id': 3, 'name': 'target-3', 'source': None}
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class ReverseNestedOneToManyTests(TestCase):
+ def setUp(self):
+ class OneToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToManySource
+ fields = ('id', 'name')
+
+ class OneToManyTargetSerializer(serializers.ModelSerializer):
+ sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
+
+ class Meta:
+ model = OneToManyTarget
+ fields = ('id', 'name', 'sources')
+
+ self.Serializer = OneToManyTargetSerializer
+
+ target = OneToManyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ source = OneToManySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_one_to_many_retrieve(self):
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1')
+
+ # Ensure source 4 is added, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create_with_invalid_data(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4}]}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
+
+ def test_one_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1-updated')
+
+ # Ensure (target 1, source 1) are updated,
+ # and everything else is as expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_delete(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+
+ # Ensure source 2 is deleted, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py
new file mode 100644
index 00000000..ff59b250
--- /dev/null
+++ b/tests/test_relations_pk.py
@@ -0,0 +1,551 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import (
+ BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
+)
+from rest_framework.compat import six
+
+
+# ManyToMany
+class ManyToManyTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ManyToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+
+
+# ForeignKey
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
+
+
+# Nullable OneToOne
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class PKManyToManyTests(TestCase):
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
+ serializer = ManyToManySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ self.assertFalse(serializer.fields['targets'].read_only)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
+ {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertFalse(serializer.fields['sources'].read_only)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]},
+ {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class PKForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 2},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': 2}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1},
+ {'id': 4, 'name': 'source-4', 'target': 2},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+
+ def test_foreign_key_with_empty(self):
+ """
+ Regression test for #1072
+
+ https://github.com/tomchristie/django-rest-framework/issues/1072
+ """
+ serializer = NullableForeignKeySourceSerializer()
+ self.assertEqual(serializer.data['target'], None)
+
+
+class PKNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEqual(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'sources': [1]},
+ # {'id': 2, 'name': 'target-2', 'sources': []},
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class PKNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=new_target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'nullable_source': None},
+ {'id': 2, 'name': 'target-2', 'nullable_source': 1},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+# The below models and tests ensure that serializer fields corresponding
+# to a ManyToManyField field with a user-specified ``through`` model are
+# set to read only
+
+
+class ManyToManyThroughTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManyThrough(models.Model):
+ source = models.ForeignKey('ManyToManyThroughSource')
+ target = models.ForeignKey(ManyToManyThroughTarget)
+
+
+class ManyToManyThroughSource(models.Model):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyThroughTarget,
+ related_name='sources',
+ through='ManyToManyThrough')
+
+
+class ManyToManyThroughTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyThroughTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ManyToManyThroughSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyThroughSource
+ fields = ('id', 'name', 'targets')
+
+
+class PKManyToManyThroughTests(TestCase):
+ def setUp(self):
+ self.source = ManyToManyThroughSource.objects.create(
+ name='through-source-1')
+ self.target = ManyToManyThroughTarget.objects.create(
+ name='through-target-1')
+
+ def test_many_to_many_create(self):
+ data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]}
+ serializer = ManyToManyThroughSourceSerializer(data=data)
+ self.assertTrue(serializer.fields['targets'].read_only)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(obj.name, 'source-2')
+ self.assertEqual(obj.targets.count(), 0)
+
+ def test_many_to_many_reverse_create(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]}
+ serializer = ManyToManyThroughTargetSerializer(data=data)
+ self.assertTrue(serializer.fields['sources'].read_only)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ obj = serializer.save()
+ self.assertEqual(obj.name, 'target-2')
+ self.assertEqual(obj.sources.count(), 0)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+
+class PrimaryKeyRelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py
new file mode 100644
index 00000000..97ebf23a
--- /dev/null
+++ b/tests/test_relations_slug.py
@@ -0,0 +1,257 @@
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.SlugRelatedField(many=True, slug_field='name')
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name')
+
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name', required=False)
+
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# TODO: M2M Tests, FKTests (Non-nullable), One2One
+class SlugForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-2'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 123}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ serializer.is_valid()
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'},
+ {'id': 4, 'name': 'source-4', 'target': 'target-2'},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+
+
+class SlugNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
new file mode 100644
index 00000000..b41cff39
--- /dev/null
+++ b/tests/test_renderers.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+from decimal import Decimal
+from django.core.cache import cache
+from django.db import models
+from django.test import TestCase
+from django.utils import unittest
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import status, permissions
+from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
+ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
+from rest_framework.parsers import YAMLParser, XMLParser
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from collections import MutableMapping
+import datetime
+import json
+import pickle
+import re
+
+
+DUMMYSTATUS = status.HTTP_200_OK
+DUMMYCONTENT = 'dummycontent'
+
+RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
+RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
+
+
+expected_results = [
+ ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator
+]
+
+
+class DummyTestModel(models.Model):
+ name = models.CharField(max_length=42, default='')
+
+
+class BasicRendererTests(TestCase):
+ def test_expected_results(self):
+ for value, renderer_cls, expected in expected_results:
+ output = renderer_cls().render(value)
+ self.assertEqual(output, expected)
+
+
+class RendererA(BaseRenderer):
+ media_type = 'mock/renderera'
+ format = "formata"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_A_SERIALIZER(data)
+
+
+class RendererB(BaseRenderer):
+ media_type = 'mock/rendererb'
+ format = "formatb"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_B_SERIALIZER(data)
+
+
+class MockView(APIView):
+ renderer_classes = (RendererA, RendererB)
+
+ def get(self, request, **kwargs):
+ response = Response(DUMMYCONTENT, status=DUMMYSTATUS)
+ return response
+
+
+class MockGETView(APIView):
+ def get(self, request, **kwargs):
+ return Response({'foo': ['bar', 'baz']})
+
+
+
+class MockPOSTView(APIView):
+ def post(self, request, **kwargs):
+ return Response({'foo': request.DATA})
+
+
+class EmptyGETView(APIView):
+ renderer_classes = (JSONRenderer,)
+
+ def get(self, request, **kwargs):
+ return Response(status=status.HTTP_204_NO_CONTENT)
+
+
+class HTMLView(APIView):
+ renderer_classes = (BrowsableAPIRenderer, )
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLView1(APIView):
+ renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+urlpatterns = patterns('',
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^cache$', MockGETView.as_view()),
+ url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
+ url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
+ url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
+ url(r'^html$', HTMLView.as_view()),
+ url(r'^html1$', HTMLView1.as_view()),
+ url(r'^empty$', EmptyGETView.as_view()),
+ url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+class POSTDeniedPermission(permissions.BasePermission):
+ def has_permission(self, request, view):
+ return request.method != 'POST'
+
+
+class POSTDeniedView(APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = (POSTDeniedPermission,)
+
+ def get(self, request):
+ return Response()
+
+ def post(self, request):
+ return Response()
+
+ def put(self, request):
+ return Response()
+
+ def patch(self, request):
+ return Response()
+
+
+class DocumentingRendererTests(TestCase):
+ def test_only_permitted_forms_are_displayed(self):
+ view = POSTDeniedView.as_view()
+ request = APIRequestFactory().get('/')
+ response = view(request).render()
+ self.assertNotContains(response, '>POST<')
+ self.assertContains(response, '>PUT<')
+ self.assertContains(response, '>PATCH<')
+
+
+class RendererEndToEndTests(TestCase):
+ """
+ End-to-end testing of renderers using an RendererMixin on a generic view.
+ """
+
+ urls = 'tests.test_renderers'
+
+ def test_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, six.b(''))
+
+ def test_default_renderer_serializes_content_on_accept_any(self):
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for the default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_non_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for a non-default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
+ """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
+ resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
+ self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_parse_error_renderers_browsable_api(self):
+ """Invalid data should still render the browsable API correctly."""
+ resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
+
+ def test_204_no_content_responses_have_no_content_type_set(self):
+ """
+ Regression test for #1196
+
+ https://github.com/tomchristie/django-rest-framework/issues/1196
+ """
+ resp = self.client.get('/empty')
+ self.assertEqual(resp.get('Content-Type', None), None)
+ self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_contains_headers_of_api_response(self):
+ """
+ Issue #1437
+
+ Test we display the headers of the API response and not those from the
+ HTML response
+ """
+ resp = self.client.get('/html1')
+ self.assertContains(resp, '>GET, HEAD, OPTIONS<')
+ self.assertContains(resp, '>application/json<')
+ self.assertNotContains(resp, '>text/html; charset=utf-8<')
+
+
+_flat_repr = '{"foo": ["bar", "baz"]}'
+_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
+
+
+def strip_trailing_whitespace(content):
+ """
+ Seems to be some inconsistencies re. trailing whitespace with
+ different versions of the json lib.
+ """
+ return re.sub(' +\n', '\n', content)
+
+
+class JSONRendererTests(TestCase):
+ """
+ Tests specific to the JSON Renderer
+ """
+
+ def test_render_lazy_strings(self):
+ """
+ JSONRenderer should deal with lazy translated strings.
+ """
+ ret = JSONRenderer().render(_('test'))
+ self.assertEqual(ret, b'"test"')
+
+ def test_render_queryset_values(self):
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, [{'id': o.id, 'name': o.name}])
+
+ def test_render_queryset_values_list(self):
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values_list('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, [[o.id, o.name]])
+
+ def test_render_dict_abc_obj(self):
+ class Dict(MutableMapping):
+ def __init__(self):
+ self._dict = dict()
+ def __getitem__(self, key):
+ return self._dict.__getitem__(key)
+ def __setitem__(self, key, value):
+ return self._dict.__setitem__(key, value)
+ def __delitem__(self, key):
+ return self._dict.__delitem__(key)
+ def __iter__(self):
+ return self._dict.__iter__()
+ def __len__(self):
+ return self._dict.__len__()
+ def keys(self):
+ return self._dict.keys()
+
+ x = Dict()
+ x['key'] = 'string value'
+ x[2] = 3
+ ret = JSONRenderer().render(x)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, {'key': 'string value', '2': 3})
+
+ def test_render_obj_with_getitem(self):
+ class DictLike(object):
+ def __init__(self):
+ self._dict = {}
+ def set(self, value):
+ self._dict = dict(value)
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ x = DictLike()
+ x.set({'a': 1, 'b': 'string'})
+ with self.assertRaises(TypeError):
+ JSONRenderer().render(x)
+
+ def test_without_content_type_args(self):
+ """
+ Test basic JSON rendering.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ # Fix failing test case which depends on version of JSON library.
+ self.assertEqual(content.decode('utf-8'), _flat_repr)
+
+ def test_with_content_type_args(self):
+ """
+ Test JSON rendering with additional content type arguments supplied.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json; indent=2')
+ self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
+
+ def test_check_ascii(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8'))
+
+
+class UnicodeJSONRendererTests(TestCase):
+ """
+ Tests specific for the Unicode JSON Renderer
+ """
+ def test_proper_encoding(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = UnicodeJSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8'))
+
+
+class JSONPRendererTests(TestCase):
+ """
+ Tests specific to the JSONP Renderer
+ """
+
+ urls = 'tests.test_renderers'
+
+ def test_without_callback_with_json_renderer(self):
+ """
+ Test JSONP rendering with View JSON Renderer.
+ """
+ resp = self.client.get('/jsonp/jsonrenderer',
+ HTTP_ACCEPT='application/javascript')
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
+ self.assertEqual(resp.content,
+ ('callback(%s);' % _flat_repr).encode('ascii'))
+
+ def test_without_callback_without_json_renderer(self):
+ """
+ Test JSONP rendering without View JSON Renderer.
+ """
+ resp = self.client.get('/jsonp/nojsonrenderer',
+ HTTP_ACCEPT='application/javascript')
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
+ self.assertEqual(resp.content,
+ ('callback(%s);' % _flat_repr).encode('ascii'))
+
+ def test_with_callback(self):
+ """
+ Test JSONP rendering with callback function name.
+ """
+ callback_func = 'myjsonpcallback'
+ resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
+ HTTP_ACCEPT='application/javascript')
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
+ self.assertEqual(resp.content,
+ ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
+
+
+if yaml:
+ _yaml_repr = 'foo: [bar, baz]\n'
+
+ class YAMLRendererTests(TestCase):
+ """
+ Tests specific to the YAML Renderer
+ """
+
+ def test_render(self):
+ """
+ Test basic YAML rendering.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = YAMLRenderer()
+ content = renderer.render(obj, 'application/yaml')
+ self.assertEqual(content, _yaml_repr)
+
+ def test_render_and_parse(self):
+ """
+ Test rendering and then parsing returns the original object.
+ IE obj -> render -> parse -> obj.
+ """
+ obj = {'foo': ['bar', 'baz']}
+
+ renderer = YAMLRenderer()
+ parser = YAMLParser()
+
+ content = renderer.render(obj, 'application/yaml')
+ data = parser.parse(StringIO(content))
+ self.assertEqual(obj, data)
+
+ def test_render_decimal(self):
+ """
+ Test YAML decimal rendering.
+ """
+ renderer = YAMLRenderer()
+ content = renderer.render({'field': Decimal('111.2')}, 'application/yaml')
+ self.assertYAMLContains(content, "field: '111.2'")
+
+ def assertYAMLContains(self, content, string):
+ self.assertTrue(string in content, '%r not in %r' % (string, content))
+
+
+class XMLRendererTestCase(TestCase):
+ """
+ Tests specific to the XML Renderer
+ """
+
+ _complex_data = {
+ "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
+ "name": "name",
+ "sub_data_list": [
+ {
+ "sub_id": 1,
+ "sub_name": "first"
+ },
+ {
+ "sub_id": 2,
+ "sub_name": "second"
+ }
+ ]
+ }
+
+ def test_render_string(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 'astring'}, 'application/xml')
+ self.assertXMLContains(content, '<field>astring</field>')
+
+ def test_render_integer(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 111}, 'application/xml')
+ self.assertXMLContains(content, '<field>111</field>')
+
+ def test_render_datetime(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({
+ 'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
+ }, 'application/xml')
+ self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
+
+ def test_render_float(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 123.4}, 'application/xml')
+ self.assertXMLContains(content, '<field>123.4</field>')
+
+ def test_render_decimal(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
+ self.assertXMLContains(content, '<field>111.2</field>')
+
+ def test_render_none(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': None}, 'application/xml')
+ self.assertXMLContains(content, '<field></field>')
+
+ def test_render_complex_data(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render(self._complex_data, 'application/xml')
+ self.assertXMLContains(content, '<sub_name>first</sub_name>')
+ self.assertXMLContains(content, '<sub_name>second</sub_name>')
+
+ @unittest.skipUnless(etree, 'defusedxml not installed')
+ def test_render_and_parse_complex_data(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = StringIO(renderer.render(self._complex_data, 'application/xml'))
+
+ parser = XMLParser()
+ complex_data_out = parser.parse(content)
+ error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
+ self.assertEqual(self._complex_data, complex_data_out, error_msg)
+
+ def assertXMLContains(self, xml, string):
+ self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
+ self.assertTrue(xml.endswith('</root>'))
+ self.assertTrue(string in xml, '%r not in %r' % (string, xml))
+
+
+# Tests for caching issue, #346
+class CacheRenderTest(TestCase):
+ """
+ Tests specific to caching responses
+ """
+
+ urls = 'tests.test_renderers'
+
+ cache_key = 'just_a_cache_key'
+
+ @classmethod
+ def _get_pickling_errors(cls, obj, seen=None):
+ """ Return any errors that would be raised if `obj' is pickled
+ Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
+ """
+ if seen == None:
+ seen = []
+ try:
+ state = obj.__getstate__()
+ except AttributeError:
+ return
+ if state == None:
+ return
+ if isinstance(state, tuple):
+ if not isinstance(state[0], dict):
+ state = state[1]
+ else:
+ state = state[0].update(state[1])
+ result = {}
+ for i in state:
+ try:
+ pickle.dumps(state[i], protocol=2)
+ except pickle.PicklingError:
+ if not state[i] in seen:
+ seen.append(state[i])
+ result[i] = cls._get_pickling_errors(state[i], seen)
+ return result
+
+ def http_resp(self, http_method, url):
+ """
+ Simple wrapper for Client http requests
+ Removes the `client' and `request' attributes from as they are
+ added by django.test.client.Client and not part of caching
+ responses outside of tests.
+ """
+ method = getattr(self.client, http_method)
+ resp = method(url)
+ del resp.client, resp.request
+ return resp
+
+ def test_obj_pickling(self):
+ """
+ Test that responses are properly pickled
+ """
+ resp = self.http_resp('get', '/cache')
+
+ # Make sure that no pickling errors occurred
+ self.assertEqual(self._get_pickling_errors(resp), {})
+
+ # Unfortunately LocMem backend doesn't raise PickleErrors but returns
+ # None instead.
+ cache.set(self.cache_key, resp)
+ self.assertTrue(cache.get(self.cache_key) is not None)
+
+ def test_head_caching(self):
+ """
+ Test caching of HEAD requests
+ """
+ resp = self.http_resp('head', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+
+ def test_get_caching(self):
+ """
+ Test caching of GET requests
+ """
+ resp = self.http_resp('get', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+ self.assertEqual(cached_resp.content, resp.content)
diff --git a/tests/test_request.py b/tests/test_request.py
new file mode 100644
index 00000000..0a9355f0
--- /dev/null
+++ b/tests/test_request.py
@@ -0,0 +1,347 @@
+"""
+Tests for content parsing, and form-overloaded content parsing.
+"""
+from __future__ import unicode_literals
+from django.contrib.auth.models import User
+from django.contrib.auth import authenticate, login, logout
+from django.contrib.sessions.middleware import SessionMiddleware
+from django.core.handlers.wsgi import WSGIRequest
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.authentication import SessionAuthentication
+from rest_framework.compat import patterns
+from rest_framework.parsers import (
+ BaseParser,
+ FormParser,
+ MultiPartParser,
+ JSONParser
+)
+from rest_framework.request import Request, Empty
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory, APIClient
+from rest_framework.views import APIView
+from rest_framework.compat import six
+from io import BytesIO
+import json
+
+
+factory = APIRequestFactory()
+
+
+class PlainTextParser(BaseParser):
+ media_type = 'text/plain'
+
+ def parse(self, stream, media_type=None, parser_context=None):
+ """
+ Returns a 2-tuple of `(data, files)`.
+
+ `data` will simply be a string representing the body of the request.
+ `files` will always be `None`.
+ """
+ return stream.read()
+
+
+class TestMethodOverloading(TestCase):
+ def test_method(self):
+ """
+ Request methods should be same as underlying request.
+ """
+ request = Request(factory.get('/'))
+ self.assertEqual(request.method, 'GET')
+ request = Request(factory.post('/'))
+ self.assertEqual(request.method, 'POST')
+
+ def test_overloaded_method(self):
+ """
+ POST requests can be overloaded to another method by setting a
+ reserved form field
+ """
+ request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
+ self.assertEqual(request.method, 'DELETE')
+
+ def test_x_http_method_override_header(self):
+ """
+ POST requests can also be overloaded to another method by setting
+ the X-HTTP-Method-Override header.
+ """
+ request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
+ request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
+
+class TestContentParsing(TestCase):
+ def test_standard_behaviour_determines_no_content_GET(self):
+ """
+ Ensure request.DATA returns empty QueryDict for GET request.
+ """
+ request = Request(factory.get('/'))
+ self.assertEqual(request.DATA, {})
+
+ def test_standard_behaviour_determines_no_content_HEAD(self):
+ """
+ Ensure request.DATA returns empty QueryDict for HEAD request.
+ """
+ request = Request(factory.head('/'))
+ self.assertEqual(request.DATA, {})
+
+ def test_request_DATA_with_form_content(self):
+ """
+ Ensure request.DATA returns content for POST request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
+
+ def test_request_DATA_with_text_content(self):
+ """
+ Ensure request.DATA returns content for POST request with
+ non-form content.
+ """
+ content = six.b('qwerty')
+ content_type = 'text/plain'
+ request = Request(factory.post('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(),)
+ self.assertEqual(request.DATA, content)
+
+ def test_request_POST_with_form_content(self):
+ """
+ Ensure request.POST returns content for POST request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.POST.items()), list(data.items()))
+
+ def test_standard_behaviour_determines_form_content_PUT(self):
+ """
+ Ensure request.DATA returns content for PUT request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.put('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
+
+ def test_standard_behaviour_determines_non_form_content_PUT(self):
+ """
+ Ensure request.DATA returns content for PUT request with
+ non-form content.
+ """
+ content = six.b('qwerty')
+ content_type = 'text/plain'
+ request = Request(factory.put('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(), )
+ self.assertEqual(request.DATA, content)
+
+ def test_overloaded_behaviour_allows_content_tunnelling(self):
+ """
+ Ensure request.DATA returns content for overloaded POST request.
+ """
+ json_data = {'foobar': 'qwerty'}
+ content = json.dumps(json_data)
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = Request(factory.post('/', form_data))
+ request.parsers = (JSONParser(), )
+ self.assertEqual(request.DATA, json_data)
+
+ def test_form_POST_unicode(self):
+ """
+ JSON POST via default web interface with unicode data
+ """
+ # Note: environ and other variables here have simplified content compared to real Request
+ CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D'
+ environ = {
+ 'REQUEST_METHOD': 'POST',
+ 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
+ 'CONTENT_LENGTH': len(CONTENT),
+ 'wsgi.input': BytesIO(CONTENT),
+ }
+ wsgi_request = WSGIRequest(environ=environ)
+ wsgi_request._load_post_and_files()
+ parsers = (JSONParser(), FormParser(), MultiPartParser())
+ parser_context = {
+ 'encoding': 'utf-8',
+ 'kwargs': {},
+ 'args': (),
+ }
+ request = Request(wsgi_request, parsers=parsers, parser_context=parser_context)
+ method = request.method
+ self.assertEqual(method, 'POST')
+ self.assertEqual(request._content_type, 'application/json')
+ self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}')
+ self.assertEqual(request._data, Empty)
+ self.assertEqual(request._files, Empty)
+
+ # def test_accessing_post_after_data_form(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in
+ # form request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # request = factory.post('/', data=data)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), data.items())
+
+ # def test_accessing_post_after_data_for_json(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+
+ # request = factory.post('/', content, content_type=content_type,
+ # parsers=parsers)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), [])
+
+ # def test_accessing_post_after_data_for_overloaded_json(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in overloaded
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # form_data = {Request._CONTENT_PARAM: content,
+ # Request._CONTENTTYPE_PARAM: content_type}
+
+ # request = factory.post('/', form_data, parsers=parsers)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), form_data.items())
+
+ # def test_accessing_data_after_post_form(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in
+ # form request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # parsers = (FormParser, MultiPartParser)
+ # request = factory.post('/', data, parsers=parsers)
+
+ # self.assertEqual(request.POST.items(), data.items())
+ # self.assertEqual(request.DATA.items(), data.items())
+
+ # def test_accessing_data_after_post_for_json(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # request = factory.post('/', content, content_type=content_type,
+ # parsers=parsers)
+ # self.assertEqual(request.POST.items(), [])
+ # self.assertEqual(request.DATA.items(), data.items())
+
+ # def test_accessing_data_after_post_for_overloaded_json(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in overloaded
+ # json request
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # form_data = {Request._CONTENT_PARAM: content,
+ # Request._CONTENTTYPE_PARAM: content_type}
+
+ # request = factory.post('/', form_data, parsers=parsers)
+ # self.assertEqual(request.POST.items(), form_data.items())
+ # self.assertEqual(request.DATA.items(), data.items())
+
+
+class MockView(APIView):
+ authentication_classes = (SessionAuthentication,)
+
+ def post(self, request):
+ if request.POST.get('example') is not None:
+ return Response(status=status.HTTP_200_OK)
+
+ return Response(status=status.INTERNAL_SERVER_ERROR)
+
+urlpatterns = patterns('',
+ (r'^$', MockView.as_view()),
+)
+
+
+class TestContentParsingWithAuthentication(TestCase):
+ urls = 'tests.test_request'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
+ """
+ Ensures request.POST exists after SessionAuthentication when user
+ doesn't log in.
+ """
+ content = {'example': 'example'}
+
+ response = self.client.post('/', content)
+ self.assertEqual(status.HTTP_200_OK, response.status_code)
+
+ response = self.csrf_client.post('/', content)
+ self.assertEqual(status.HTTP_200_OK, response.status_code)
+
+ # def test_user_logged_in_authentication_has_post_when_logged_in(self):
+ # """Ensures request.POST exists after UserLoggedInAuthentication when user does log in"""
+ # self.client.login(username='john', password='password')
+ # self.csrf_client.login(username='john', password='password')
+ # content = {'example': 'example'}
+
+ # response = self.client.post('/', content)
+ # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
+
+ # response = self.csrf_client.post('/', content)
+ # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
+
+
+class TestUserSetter(TestCase):
+
+ def setUp(self):
+ # Pass request object through session middleware so session is
+ # available to login and logout functions
+ self.request = Request(factory.get('/'))
+ SessionMiddleware().process_request(self.request)
+
+ User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
+ self.user = authenticate(username='ringo', password='yellow')
+
+ def test_user_can_be_set(self):
+ self.request.user = self.user
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_login(self):
+ login(self.request, self.user)
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_logout(self):
+ self.request.user = self.user
+ self.assertFalse(self.request.user.is_anonymous())
+ logout(self.request)
+ self.assertTrue(self.request.user.is_anonymous())
+
+
+class TestAuthSetter(TestCase):
+
+ def test_auth_can_be_set(self):
+ request = Request(factory.get('/'))
+ request.auth = 'DUMMY'
+ self.assertEqual(request.auth, 'DUMMY')
diff --git a/tests/test_response.py b/tests/test_response.py
new file mode 100644
index 00000000..41c0f49d
--- /dev/null
+++ b/tests/test_response.py
@@ -0,0 +1,278 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from tests.models import BasicModel, BasicModelSerializer
+from rest_framework.compat import patterns, url, include
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework import generics
+from rest_framework import routers
+from rest_framework import status
+from rest_framework.renderers import (
+ BaseRenderer,
+ JSONRenderer,
+ BrowsableAPIRenderer
+)
+from rest_framework import viewsets
+from rest_framework.settings import api_settings
+from rest_framework.compat import six
+
+
+class MockPickleRenderer(BaseRenderer):
+ media_type = 'application/pickle'
+
+
+class MockJsonRenderer(BaseRenderer):
+ media_type = 'application/json'
+
+
+class MockTextMediaRenderer(BaseRenderer):
+ media_type = 'text/html'
+
+DUMMYSTATUS = status.HTTP_200_OK
+DUMMYCONTENT = 'dummycontent'
+
+RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
+RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
+
+
+class RendererA(BaseRenderer):
+ media_type = 'mock/renderera'
+ format = "formata"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_A_SERIALIZER(data)
+
+
+class RendererB(BaseRenderer):
+ media_type = 'mock/rendererb'
+ format = "formatb"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_B_SERIALIZER(data)
+
+
+class RendererC(RendererB):
+ media_type = 'mock/rendererc'
+ format = 'formatc'
+ charset = "rendererc"
+
+
+class MockView(APIView):
+ renderer_classes = (RendererA, RendererB, RendererC)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS)
+
+
+class MockViewSettingContentType(APIView):
+ renderer_classes = (RendererA, RendererB, RendererC)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
+
+
+class HTMLView(APIView):
+ renderer_classes = (BrowsableAPIRenderer, )
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLView1(APIView):
+ renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLNewModelViewSet(viewsets.ModelViewSet):
+ model = BasicModel
+
+
+class HTMLNewModelView(generics.ListCreateAPIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = []
+ serializer_class = BasicModelSerializer
+ model = BasicModel
+
+
+new_model_viewset_router = routers.DefaultRouter()
+new_model_viewset_router.register(r'', HTMLNewModelViewSet)
+
+
+urlpatterns = patterns('',
+ url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^html$', HTMLView.as_view()),
+ url(r'^html1$', HTMLView1.as_view()),
+ url(r'^html_new_model$', HTMLNewModelView.as_view()),
+ url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
+ url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
+class RendererIntegrationTests(TestCase):
+ """
+ End-to-end testing of renderers using an ResponseMixin on a generic view.
+ """
+
+ urls = 'tests.test_response'
+
+ def test_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, six.b(''))
+
+ def test_default_renderer_serializes_content_on_accept_any(self):
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for the default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_non_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for a non-default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+
+class Issue122Tests(TestCase):
+ """
+ Tests that covers #122.
+ """
+ urls = 'tests.test_response'
+
+ def test_only_html_renderer(self):
+ """
+ Test if no infinite recursion occurs.
+ """
+ self.client.get('/html')
+
+ def test_html_renderer_is_first(self):
+ """
+ Test if no infinite recursion occurs.
+ """
+ self.client.get('/html1')
+
+
+class Issue467Tests(TestCase):
+ """
+ Tests for #467
+ """
+
+ urls = 'tests.test_response'
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
+
+
+class Issue807Tests(TestCase):
+ """
+ Covers #807
+ """
+
+ urls = 'tests.test_response'
+
+ def test_does_not_append_charset_by_default(self):
+ """
+ Renderers don't include a charset unless set explicitly.
+ """
+ headers = {"HTTP_ACCEPT": RendererA.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8')
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
+ """
+ If renderer class has charset attribute declared, it gets appended
+ to Response's Content-Type
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_content_type_set_explictly_on_response(self):
+ """
+ The content type may be set explictly on the response.
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/setbyview', **headers)
+ self.assertEqual('setbyview', resp['Content-Type'])
+
+ def test_viewset_label_help_text(self):
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ 'text/html'
+ )
+ resp = self.client.get('/html_new_model_viewset/' + param)
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
diff --git a/tests/test_reverse.py b/tests/test_reverse.py
new file mode 100644
index 00000000..3d14a28f
--- /dev/null
+++ b/tests/test_reverse.py
@@ -0,0 +1,27 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.compat import patterns, url
+from rest_framework.reverse import reverse
+from rest_framework.test import APIRequestFactory
+
+factory = APIRequestFactory()
+
+
+def null_view(request):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^view$', null_view, name='view'),
+)
+
+
+class ReverseTests(TestCase):
+ """
+ Tests for fully qualified URLs when using `reverse`.
+ """
+ urls = 'tests.test_reverse'
+
+ def test_reversed_urls_are_fully_qualified(self):
+ request = factory.get('/view')
+ url = reverse('view', request=request)
+ self.assertEqual(url, 'http://testserver/view')
diff --git a/tests/test_routers.py b/tests/test_routers.py
new file mode 100644
index 00000000..084c0e27
--- /dev/null
+++ b/tests/test_routers.py
@@ -0,0 +1,216 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from django.core.exceptions import ImproperlyConfigured
+from rest_framework import serializers, viewsets, permissions
+from rest_framework.compat import include, patterns, url
+from rest_framework.decorators import link, action
+from rest_framework.response import Response
+from rest_framework.routers import SimpleRouter, DefaultRouter
+from rest_framework.test import APIRequestFactory
+
+factory = APIRequestFactory()
+
+urlpatterns = patterns('',)
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @action()
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @action(methods=['post', 'delete'])
+ def action3(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @link()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
+ # check method to function mapping
+ if endpoint == 'action3':
+ methods_map = ['post', 'delete']
+ elif endpoint.startswith('action'):
+ methods_map = ['post']
+ else:
+ methods_map = ['get']
+ for method in methods_map:
+ self.assertEqual(route.mapping[method], endpoint)
+
+
+class RouterTestModel(models.Model):
+ uuid = models.CharField(max_length=20)
+ text = models.CharField(max_length=200)
+
+
+class TestCustomLookupFields(TestCase):
+ """
+ Ensure that custom lookup fields are correctly routed.
+ """
+ urls = 'tests.test_routers'
+
+ def setUp(self):
+ class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RouterTestModel
+ lookup_field = 'uuid'
+ fields = ('url', 'uuid', 'text')
+
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ serializer_class = NoteSerializer
+ lookup_field = 'uuid'
+
+ RouterTestModel.objects.create(uuid='123', text='foo bar')
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+
+ from tests import test_routers
+ urls = getattr(test_routers, 'urlpatterns')
+ urls += patterns('',
+ url(r'^', include(self.router.urls)),
+ )
+
+ def test_custom_lookup_field_route(self):
+ detail_route = self.router.urls[-1]
+ detail_url_pattern = detail_route.regex.pattern
+ self.assertIn('<uuid>', detail_url_pattern)
+
+ def test_retrieve_lookup_field_list_view(self):
+ response = self.client.get('/notes/')
+ self.assertEqual(response.data,
+ [{
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }]
+ )
+
+ def test_retrieve_lookup_field_detail_view(self):
+ response = self.client.get('/notes/123/')
+ self.assertEqual(response.data,
+ {
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }
+ )
+
+
+class TestTrailingSlashIncluded(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_have_trailing_slash_by_default(self):
+ expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestTrailingSlashRemoved(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter(trailing_slash=False)
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_can_have_trailing_slash_removed(self):
+ expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestNameableRoot(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+ self.router = DefaultRouter()
+ self.router.root_view_name = 'nameable-root'
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_router_has_custom_name(self):
+ expected = 'nameable-root'
+ self.assertEqual(expected, self.urls[0].name)
+
+
+class TestActionKeywordArgs(TestCase):
+ """
+ Ensure keyword arguments passed in the `@action` decorator
+ are properly handled. Refs #940.
+ """
+
+ def setUp(self):
+ class TestViewSet(viewsets.ModelViewSet):
+ permission_classes = []
+
+ @action(permission_classes=[permissions.AllowAny])
+ def custom(self, request, *args, **kwargs):
+ return Response({
+ 'permission_classes': self.permission_classes
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+ self.view = self.router.urls[-1].callback
+
+ def test_action_kwargs(self):
+ request = factory.post('/test/0/custom/')
+ response = self.view(request)
+ self.assertEqual(
+ response.data,
+ {'permission_classes': [permissions.AllowAny]}
+ )
+
+
+class TestActionAppliedToExistingRoute(TestCase):
+ """
+ Ensure `@action` decorator raises an except when applied
+ to an existing route
+ """
+
+ def test_exception_raised_when_action_applied_to_existing_route(self):
+ class TestViewSet(viewsets.ModelViewSet):
+
+ @action()
+ def retrieve(self, request, *args, **kwargs):
+ return Response({
+ 'hello': 'world'
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+
+ with self.assertRaises(ImproperlyConfigured):
+ self.router.urls
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
new file mode 100644
index 00000000..18484afe
--- /dev/null
+++ b/tests/test_serializer.py
@@ -0,0 +1,1857 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+from django.db import models
+from django.db.models.fields import BLANK_CHOICE_DASH
+from django.test import TestCase
+from django.utils.datastructures import MultiValueDict
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import serializers, fields, relations
+from tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
+ BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
+ ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel)
+from tests.models import BasicModelSerializer
+import datetime
+import pickle
+
+
+class SubComment(object):
+ def __init__(self, sub_comment):
+ self.sub_comment = sub_comment
+
+
+class Comment(object):
+ def __init__(self, email, content, created):
+ self.email = email
+ self.content = content
+ self.created = created or datetime.datetime.now()
+
+ def __eq__(self, other):
+ return all([getattr(self, attr) == getattr(other, attr)
+ for attr in ('email', 'content', 'created')])
+
+ def get_sub_comment(self):
+ sub_comment = SubComment('And Merry Christmas!')
+ return sub_comment
+
+
+class CommentSerializer(serializers.Serializer):
+ email = serializers.EmailField()
+ content = serializers.CharField(max_length=1000)
+ created = serializers.DateTimeField()
+ sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
+
+ def restore_object(self, data, instance=None):
+ if instance is None:
+ return Comment(**data)
+ for key, val in data.items():
+ setattr(instance, key, val)
+ return instance
+
+
+class NamesSerializer(serializers.Serializer):
+ first = serializers.CharField()
+ last = serializers.CharField(required=False, default='')
+ initials = serializers.CharField(required=False, default='')
+
+
+class PersonIdentifierSerializer(serializers.Serializer):
+ ssn = serializers.CharField()
+ names = NamesSerializer(source='names', required=False)
+
+
+class BookSerializer(serializers.ModelSerializer):
+ isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
+
+ class Meta:
+ model = Book
+
+
+class ActionItemSerializer(serializers.ModelSerializer):
+
+ class Meta:
+ model = ActionItem
+
+class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
+ """
+ Intended to test that fields with `required=False` are excluded from validation.
+ """
+ title = serializers.CharField(required=False)
+
+ class Meta:
+ model = ActionItem
+ fields = ('title',)
+
+class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
+
+ class Meta:
+ model = ActionItem
+
+ def restore_object(self, data, instance=None):
+ if instance is None:
+ return ActionItem(**data)
+ for key, val in data.items():
+ setattr(instance, key, val)
+ return instance
+
+
+class PersonSerializer(serializers.ModelSerializer):
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age',)
+
+
+class NestedSerializer(serializers.Serializer):
+ info = serializers.Field()
+
+
+class ModelSerializerWithNestedSerializer(serializers.ModelSerializer):
+ nested = NestedSerializer(source='*')
+
+ class Meta:
+ model = Person
+
+
+class NestedSerializerWithRenamedField(serializers.Serializer):
+ renamed_info = serializers.Field(source='info')
+
+
+class ModelSerializerWithNestedSerializerWithRenamedField(serializers.ModelSerializer):
+ nested = NestedSerializerWithRenamedField(source='*')
+
+ class Meta:
+ model = Person
+
+
+class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
+ """
+ Testing for #652.
+ """
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age', 'info')
+
+
+class AlbumsSerializer(serializers.ModelSerializer):
+
+ class Meta:
+ model = Album
+ fields = ['title'] # lists are also valid options
+
+
+class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = HasPositiveIntegerAsChoice
+ fields = ['some_integer']
+
+
+class BasicTests(TestCase):
+ def setUp(self):
+ self.comment = Comment(
+ 'tom@example.com',
+ 'Happy new year!',
+ datetime.datetime(2012, 1, 1)
+ )
+ self.actionitem = ActionItem(title='Some to do item',)
+ self.data = {
+ 'email': 'tom@example.com',
+ 'content': 'Happy new year!',
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'This wont change'
+ }
+ self.expected = {
+ 'email': 'tom@example.com',
+ 'content': 'Happy new year!',
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'And Merry Christmas!'
+ }
+ self.person_data = {'name': 'dwight', 'age': 35}
+ self.person = Person(**self.person_data)
+ self.person.save()
+
+ def test_empty(self):
+ serializer = CommentSerializer()
+ expected = {
+ 'email': '',
+ 'content': '',
+ 'created': None
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_retrieve(self):
+ serializer = CommentSerializer(self.comment)
+ self.assertEqual(serializer.data, self.expected)
+
+ def test_create(self):
+ serializer = CommentSerializer(data=self.data)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+ self.assertFalse(serializer.object is expected)
+ self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
+
+ def test_create_nested(self):
+ """Test a serializer with nested data."""
+ names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ self.assertEqual(serializer.data['names'], names)
+
+ def test_create_partial_nested(self):
+ """Test a serializer with nested data which has missing fields."""
+ names = {'first': 'John'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ expected_names = {'first': 'John', 'last': '', 'initials': ''}
+ data['names'] = expected_names
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is expected_names)
+ self.assertEqual(serializer.data['names'], expected_names)
+
+ def test_null_nested(self):
+ """Test a serializer with a nonexistent nested field"""
+ data = {'ssn': '1234567890'}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ expected = {'ssn': '1234567890', 'names': None}
+ self.assertEqual(serializer.data, expected)
+
+ def test_update(self):
+ serializer = CommentSerializer(self.comment, data=self.data)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+ self.assertTrue(serializer.object is expected)
+ self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
+
+ def test_partial_update(self):
+ msg = 'Merry New Year!'
+ partial_data = {'content': msg}
+ serializer = CommentSerializer(self.comment, data=partial_data)
+ self.assertEqual(serializer.is_valid(), False)
+ serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+ self.assertTrue(serializer.object is expected)
+ self.assertEqual(serializer.data['content'], msg)
+
+ def test_model_fields_as_expected(self):
+ """
+ Make sure that the fields returned are the same as defined
+ in the Meta data
+ """
+ serializer = PersonSerializer(self.person)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'age', 'info']))
+
+ def test_field_with_dictionary(self):
+ """
+ Make sure that dictionaries from fields are left intact
+ """
+ serializer = PersonSerializer(self.person)
+ expected = self.person_data
+ self.assertEqual(serializer.data['info'], expected)
+
+ def test_read_only_fields(self):
+ """
+ Attempting to update fields set as read_only should have no effect.
+ """
+ serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(serializer.errors, {})
+ # Assert age is unchanged (35)
+ self.assertEqual(instance.age, self.person_data['age'])
+
+ def test_invalid_read_only_fields(self):
+ """
+ Regression test for #652.
+ """
+ self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+
+ def test_serializer_data_is_cleared_on_save(self):
+ """
+ Check _data attribute is cleared on `save()`
+
+ Regression test for #1116
+ — id field is not populated if `data` is accessed prior to `save()`
+ """
+ serializer = ActionItemSerializer(self.actionitem)
+ self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.')
+ serializer.save()
+ self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.')
+
+ def test_fields_marked_as_not_required_are_excluded_from_validation(self):
+ """
+ Check that fields with `required=False` are included in list of exclusions.
+ """
+ serializer = ActionItemSerializerOptionalFields(self.actionitem)
+ exclusions = serializer.get_validation_exclusions()
+ self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded')
+
+
+class DictStyleSerializer(serializers.Serializer):
+ """
+ Note that we don't have any `restore_object` method, so the default
+ case of simply returning a dict will apply.
+ """
+ email = serializers.EmailField()
+
+
+class DictStyleSerializerTests(TestCase):
+ def test_dict_style_deserialize(self):
+ """
+ Ensure serializers can deserialize into a dict.
+ """
+ data = {'email': 'foo@example.com'}
+ serializer = DictStyleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+
+ def test_dict_style_serialize(self):
+ """
+ Ensure serializers can serialize dict objects.
+ """
+ data = {'email': 'foo@example.com'}
+ serializer = DictStyleSerializer(data)
+ self.assertEqual(serializer.data, data)
+
+
+class ValidationTests(TestCase):
+ def setUp(self):
+ self.comment = Comment(
+ 'tom@example.com',
+ 'Happy new year!',
+ datetime.datetime(2012, 1, 1)
+ )
+ self.data = {
+ 'email': 'tom@example.com',
+ 'content': 'x' * 1001,
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ self.actionitem = ActionItem(title='Some to do item',)
+
+ def test_create(self):
+ serializer = CommentSerializer(data=self.data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
+
+ def test_update(self):
+ serializer = CommentSerializer(self.comment, data=self.data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
+
+ def test_update_missing_field(self):
+ data = {
+ 'content': 'xxx',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'email': ['This field is required.']})
+
+ def test_missing_bool_with_default(self):
+ """Make sure that a boolean value with a 'False' value is not
+ mistaken for not having a default."""
+ data = {
+ 'title': 'Some action item',
+ #No 'done' value.
+ }
+ serializer = ActionItemSerializer(self.actionitem, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
+
+ def test_cross_field_validation(self):
+
+ class CommentSerializerWithCrossFieldValidator(CommentSerializer):
+
+ def validate(self, attrs):
+ if attrs["email"] not in attrs["content"]:
+ raise serializers.ValidationError("Email address not in content")
+ return attrs
+
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A comment from tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'A comment from foo@bar.com'
+
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']})
+
+ def test_null_is_true_fields(self):
+ """
+ Omitting a value for null-field should validate.
+ """
+ serializer = PersonSerializer(data={'name': 'marko'})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
+
+ def test_modelserializer_max_length_exceeded(self):
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_modelserializer_max_length_exceeded_with_custom_restore(self):
+ """
+ When overriding ModelSerializer.restore_object, validation tests should still apply.
+ Regression test for #623.
+
+ https://github.com/tomchristie/django-rest-framework/pull/623
+ """
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializerCustomRestore(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_default_modelfield_max_length_exceeded(self):
+ data = {
+ 'title': 'Testing "info" field...',
+ 'info': 'x' * 13,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
+
+ def test_datetime_validation_failure(self):
+ """
+ Test DateTimeField validation errors on non-str values.
+ Regression test for #669.
+
+ https://github.com/tomchristie/django-rest-framework/issues/669
+ """
+ data = self.data
+ data['created'] = 0
+
+ serializer = CommentSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+
+ self.assertIn('created', serializer.errors)
+
+ def test_missing_model_field_exception_msg(self):
+ """
+ Assert that a meaningful exception message is outputted when the model
+ field is missing (e.g. when mistyping ``model``).
+ """
+ class BrokenModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ fields = ['some_field']
+
+ try:
+ BrokenModelSerializer()
+ except AssertionError as e:
+ self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option")
+ except:
+ self.fail('Wrong exception type thrown.')
+
+ def test_writable_star_source_on_nested_serializer(self):
+ """
+ Assert that a nested serializer instantiated with source='*' correctly
+ expands the data into the outer serializer.
+ """
+ serializer = ModelSerializerWithNestedSerializer(data={
+ 'name': 'marko',
+ 'nested': {'info': 'hi'}},
+ )
+ self.assertEqual(serializer.is_valid(), True)
+
+ def test_writable_star_source_with_inner_source_fields(self):
+ """
+ Tests that a serializer with source="*" correctly expands the
+ it's fields into the outer serializer even if they have their
+ own 'source' parameters.
+ """
+
+ serializer = ModelSerializerWithNestedSerializerWithRenamedField(data={
+ 'name': 'marko',
+ 'nested': {'renamed_info': 'hi'}},
+ )
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
+
+
+class CustomValidationTests(TestCase):
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_email(self, attrs, source):
+ attrs[source]
+ return attrs
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ def test_field_validation(self):
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'content': ['Test not in value']})
+
+ def test_missing_data(self):
+ """
+ Make sure that validate_content isn't called if the field is missing
+ """
+ incomplete_data = {
+ 'email': 'tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'content': ['This field is required.']})
+
+ def test_wrong_data(self):
+ """
+ Make sure that validate_content isn't called if the field input is wrong
+ """
+ wrong_data = {
+ 'email': 'not an email',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']})
+
+ def test_partial_update(self):
+ """
+ Make sure that validate_email isn't called when partial=True and email
+ isn't found in data.
+ """
+ initial_data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=initial_data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.object
+
+ new_content = 'An *updated* test comment'
+ partial_data = {
+ 'content': new_content
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(instance=instance,
+ data=partial_data,
+ partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.object
+ self.assertEqual(instance.content, new_content)
+
+
+class PositiveIntegerAsChoiceTests(TestCase):
+ def test_positive_integer_in_json_is_correctly_parsed(self):
+ data = {'some_integer': 1}
+ serializer = PositiveIntegerAsChoiceSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+
+
+class ModelValidationTests(TestCase):
+ def test_validate_unique(self):
+ """
+ Just check if serializers.ModelSerializer handles unique checks via .full_clean()
+ """
+ serializer = AlbumsSerializer(data={'title': 'a'})
+ serializer.is_valid()
+ serializer.save()
+ second_serializer = AlbumsSerializer(data={'title': 'a'})
+ self.assertFalse(second_serializer.is_valid())
+ self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
+
+ def test_foreign_key_is_null_with_partial(self):
+ """
+ Test ModelSerializer validation with partial=True
+
+ Specifically test that a null foreign key does not pass validation
+ """
+ album = Album(title='test')
+ album.save()
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+
+ photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
+ self.assertTrue(photo_serializer.is_valid())
+ photo = photo_serializer.save()
+
+ # Updating only the album (foreign key)
+ photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True)
+ self.assertFalse(photo_serializer.is_valid())
+ self.assertTrue('album' in photo_serializer.errors)
+ self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required'])
+
+ def test_foreign_key_with_partial(self):
+ """
+ Test ModelSerializer validation with partial=True
+
+ Specifically test foreign key validation.
+ """
+
+ album = Album(title='test')
+ album.save()
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+
+ photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
+ self.assertTrue(photo_serializer.is_valid())
+ photo = photo_serializer.save()
+
+ # Updating only the album (foreign key)
+ photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True)
+ self.assertTrue(photo_serializer.is_valid())
+ self.assertTrue(photo_serializer.save())
+
+ # Updating only the description
+ photo_serializer = PhotoSerializer(instance=photo,
+ data={'description': 'new'},
+ partial=True)
+
+ self.assertTrue(photo_serializer.is_valid())
+ self.assertTrue(photo_serializer.save())
+
+
+class RegexValidationTest(TestCase):
+ def test_create_failed(self):
+ serializer = BookSerializer(data={'isbn': '1234567890'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': '12345678901234'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
+
+ def test_create_success(self):
+ serializer = BookSerializer(data={'isbn': '1234567890123'})
+ self.assertTrue(serializer.is_valid())
+
+
+class MetadataTests(TestCase):
+ def test_empty(self):
+ serializer = CommentSerializer()
+ expected = {
+ 'email': serializers.CharField,
+ 'content': serializers.CharField,
+ 'created': serializers.DateTimeField
+ }
+ for field_name, field in expected.items():
+ self.assertTrue(isinstance(serializer.data.fields[field_name], field))
+
+
+class ManyToManyTests(TestCase):
+ def setUp(self):
+ class ManyToManySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyModel
+
+ self.serializer_class = ManyToManySerializer
+
+ # An anchor instance to use for the relationship
+ self.anchor = Anchor()
+ self.anchor.save()
+
+ # A model instance with a many to many relationship to the anchor
+ self.instance = ManyToManyModel()
+ self.instance.save()
+ self.instance.rel.add(self.anchor)
+
+ # A serialized representation of the model instance
+ self.data = {'id': 1, 'rel': [self.anchor.id]}
+
+ def test_retrieve(self):
+ """
+ Serialize an instance of a model with a ManyToMany relationship.
+ """
+ serializer = self.serializer_class(instance=self.instance)
+ expected = self.data
+ self.assertEqual(serializer.data, expected)
+
+ def test_create(self):
+ """
+ Create an instance of a model with a ManyToMany relationship.
+ """
+ data = {'rel': [self.anchor.id]}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
+
+ def test_update(self):
+ """
+ Update an instance of a model with a ManyToMany relationship.
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': [self.anchor.id, new_anchor.id]}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor])
+
+ def test_create_empty_relationship(self):
+ """
+ Create an instance of a model with a ManyToMany relationship,
+ containing no items.
+ """
+ data = {'rel': []}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [])
+
+ def test_update_empty_relationship(self):
+ """
+ Update an instance of a model with a ManyToMany relationship,
+ containing no items.
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': []}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(list(instance.rel.all()), [])
+
+ def test_create_empty_relationship_flat_data(self):
+ """
+ Create an instance of a model with a ManyToMany relationship,
+ containing no items, using a representation that does not support
+ lists (eg form data).
+ """
+ data = MultiValueDict()
+ data.setlist('rel', [''])
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [])
+
+
+class ReadOnlyManyToManyTests(TestCase):
+ def setUp(self):
+ class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
+ rel = serializers.RelatedField(many=True, read_only=True)
+
+ class Meta:
+ model = ReadOnlyManyToManyModel
+
+ self.serializer_class = ReadOnlyManyToManySerializer
+
+ # An anchor instance to use for the relationship
+ self.anchor = Anchor()
+ self.anchor.save()
+
+ # A model instance with a many to many relationship to the anchor
+ self.instance = ReadOnlyManyToManyModel()
+ self.instance.save()
+ self.instance.rel.add(self.anchor)
+
+ # A serialized representation of the model instance
+ self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
+
+ def test_update(self):
+ """
+ Attempt to update an instance of a model with a ManyToMany
+ relationship. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': [self.anchor.id, new_anchor.id]}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
+
+ def test_update_without_relationship(self):
+ """
+ Attempt to update an instance of a model where many to ManyToMany
+ relationship is not supplied. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
+
+
+class DefaultValueTests(TestCase):
+ def setUp(self):
+ class DefaultValueSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DefaultValueModel
+
+ self.serializer_class = DefaultValueSerializer
+ self.objects = DefaultValueModel.objects
+
+ def test_create_using_default(self):
+ data = {}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'foobar')
+
+ def test_create_overriding_default(self):
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'overridden')
+
+ def test_partial_update_default(self):
+ """ Regression test for issue #532 """
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data, partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ data = {'extra': 'extra_value'}
+ serializer = self.serializer_class(instance=instance, data=data, partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ self.assertEqual(instance.extra, 'extra_value')
+ self.assertEqual(instance.text, 'overridden')
+
+
+class CallableDefaultValueTests(TestCase):
+ def setUp(self):
+ class CallableDefaultValueSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = CallableDefaultValueModel
+
+ self.serializer_class = CallableDefaultValueSerializer
+ self.objects = CallableDefaultValueModel.objects
+
+ def test_create_using_default(self):
+ data = {}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'foobar')
+
+ def test_create_overriding_default(self):
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'overridden')
+
+
+class ManyRelatedTests(TestCase):
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'title': 'Test blog post',
+ 'comments': [
+ {'text': 'I hate this blog post'},
+ {'text': 'I love this blog post'}
+ ]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_callable_source(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ first_comment = BlogPostCommentSerializer(source='get_first_comment')
+
+ serializer = BlogPostSerializer(post)
+
+ expected = {
+ 'title': 'Test blog post',
+ 'first_comment': {'text': 'I love this blog post'}
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+class RelatedTraversalTest(TestCase):
+ def test_nested_traversal(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ user = Person.objects.create(name="django")
+ post = BlogPost.objects.create(title="Test blog post", writer=user)
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ("name", "age")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPostComment
+ fields = ("text", "post_owner")
+
+ text = serializers.CharField()
+ post_owner = PersonSerializer(source='blog_post.writer')
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+
+ expected = {
+ 'title': 'Test blog post',
+ 'comments': [{
+ 'text': 'I love this blog post',
+ 'post_owner': {
+ "name": "django",
+ "age": None
+ }
+ }]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_nested_traversal_with_none(self):
+ """
+ If a component of the dotted.source is None, return None for the field.
+ """
+ from tests.models import NullableForeignKeySource
+ instance = NullableForeignKeySource.objects.create(name='Source with null FK')
+
+ class NullableSourceSerializer(serializers.Serializer):
+ target_name = serializers.Field(source='target.name')
+
+ serializer = NullableSourceSerializer(instance=instance)
+
+ expected = {
+ 'target_name': None,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+class SerializerMethodFieldTests(TestCase):
+ def setUp(self):
+
+ class BoopSerializer(serializers.Serializer):
+ beep = serializers.SerializerMethodField('get_beep')
+ boop = serializers.Field()
+ boop_count = serializers.SerializerMethodField('get_boop_count')
+
+ def get_beep(self, obj):
+ return 'hello!'
+
+ def get_boop_count(self, obj):
+ return len(obj.boop)
+
+ self.serializer_class = BoopSerializer
+
+ def test_serializer_method_field(self):
+
+ class MyModel(object):
+ boop = ['a', 'b', 'c']
+
+ source_data = MyModel()
+
+ serializer = self.serializer_class(source_data)
+
+ expected = {
+ 'beep': 'hello!',
+ 'boop': ['a', 'b', 'c'],
+ 'boop_count': 3,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+# Test for issue #324
+class BlankFieldTests(TestCase):
+ def setUp(self):
+
+ class BlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlankFieldModel
+
+ class BlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField(required=False)
+
+ class NotBlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+ class NotBlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField()
+
+ self.model_serializer_class = BlankFieldModelSerializer
+ self.serializer_class = BlankFieldSerializer
+ self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
+ self.not_blank_serializer_class = NotBlankFieldSerializer
+ self.data = {'title': ''}
+
+ def test_create_blank_field(self):
+ serializer = self.serializer_class(data=self.data)
+ self.assertEqual(serializer.is_valid(), True)
+
+ def test_create_model_blank_field(self):
+ serializer = self.model_serializer_class(data=self.data)
+ self.assertEqual(serializer.is_valid(), True)
+
+ def test_create_model_null_field(self):
+ serializer = self.model_serializer_class(data={'title': None})
+ self.assertEqual(serializer.is_valid(), True)
+
+ def test_create_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a non-model serializer
+ """
+ serializer = self.not_blank_serializer_class(data=self.data)
+ self.assertEqual(serializer.is_valid(), False)
+
+ def test_create_model_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a model serializer
+ """
+ serializer = self.not_blank_model_serializer_class(data=self.data)
+ self.assertEqual(serializer.is_valid(), False)
+
+ def test_create_model_empty_field(self):
+ serializer = self.model_serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+
+#test for issue #460
+class SerializerPickleTests(TestCase):
+ """
+ Test pickleability of the output of Serializers
+ """
+ def test_pickle_simple_model_serializer_data(self):
+ """
+ Test simple serializer
+ """
+ pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data)
+
+ def test_pickle_inner_serializer(self):
+ """
+ Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will
+ have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle.
+ See DictWithMetadata.__getstate__
+ """
+ class InnerPersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ('name', 'age')
+ pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0)
+
+ def test_getstate_method_should_not_return_none(self):
+ """
+ Regression test for #645.
+ """
+ data = serializers.DictWithMetadata({1: 1})
+ self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1}))
+
+ def test_serializer_data_is_pickleable(self):
+ """
+ Another regression test for #645.
+ """
+ data = serializers.SortedDictWithMetadata({1: 1})
+ repr(pickle.loads(pickle.dumps(data, 0)))
+
+
+# test for issue #725
+class SeveralChoicesModel(models.Model):
+ color = models.CharField(
+ max_length=10,
+ choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')],
+ blank=False
+ )
+ drink = models.CharField(
+ max_length=10,
+ choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')],
+ blank=False,
+ default='beer'
+ )
+ os = models.CharField(
+ max_length=10,
+ choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')],
+ blank=True
+ )
+ music_genre = models.CharField(
+ max_length=10,
+ choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')],
+ blank=True,
+ default='metal'
+ )
+
+
+class SerializerChoiceFields(TestCase):
+
+ def setUp(self):
+ super(SerializerChoiceFields, self).setUp()
+
+ class SeveralChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SeveralChoicesModel
+ fields = ('color', 'drink', 'os', 'music_genre')
+
+ self.several_choices_serializer = SeveralChoicesSerializer
+
+ def test_choices_blank_false_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['color'].choices,
+ [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')]
+ )
+
+ def test_choices_blank_false_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['drink'].choices,
+ [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')]
+ )
+
+ def test_choices_blank_true_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['os'].choices,
+ BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')]
+ )
+
+ def test_choices_blank_true_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['music_genre'].choices,
+ BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')]
+ )
+
+
+# Regression tests for #675
+class Ticket(models.Model):
+ assigned = models.ForeignKey(
+ Person, related_name='assigned_tickets')
+ reviewer = models.ForeignKey(
+ Person, blank=True, null=True, related_name='reviewed_tickets')
+
+
+class SerializerRelatedChoicesTest(TestCase):
+
+ def setUp(self):
+ super(SerializerRelatedChoicesTest, self).setUp()
+
+ class RelatedChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Ticket
+ fields = ('assigned', 'reviewer')
+
+ self.related_fields_serializer = RelatedChoicesSerializer
+
+ def test_empty_queryset_required(self):
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['assigned'].queryset.count(), 0)
+ self.assertEqual(
+ [x for x in serializer.fields['assigned'].widget.choices],
+ []
+ )
+
+ def test_empty_queryset_not_required(self):
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0)
+ self.assertEqual(
+ [x for x in serializer.fields['reviewer'].widget.choices],
+ [('', '---------')]
+ )
+
+ def test_with_some_persons_required(self):
+ Person.objects.create(name="Lionel Messi")
+ Person.objects.create(name="Xavi Hernandez")
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['assigned'].queryset.count(), 2)
+ self.assertEqual(
+ [x for x in serializer.fields['assigned'].widget.choices],
+ [(1, 'Person object - 1'), (2, 'Person object - 2')]
+ )
+
+ def test_with_some_persons_not_required(self):
+ Person.objects.create(name="Lionel Messi")
+ Person.objects.create(name="Xavi Hernandez")
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2)
+ self.assertEqual(
+ [x for x in serializer.fields['reviewer'].widget.choices],
+ [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')]
+ )
+
+
+class DepthTest(TestCase):
+ def test_implicit_nesting(self):
+
+ writer = Person.objects.create(name="django", age=1)
+ post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPostComment
+ depth = 2
+
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_explicit_nesting(self):
+ writer = Person.objects.create(name="django", age=1)
+ post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ writer = PersonSerializer()
+
+ class Meta:
+ model = BlogPost
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ blog_post = BlogPostSerializer()
+
+ class Meta:
+ model = BlogPostComment
+
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
+
+ self.assertEqual(serializer.data, expected)
+
+
+class NestedSerializerContextTests(TestCase):
+
+ def test_nested_serializer_context(self):
+ """
+ Regression for #497
+
+ https://github.com/tomchristie/django-rest-framework/issues/497
+ """
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+ fields = ("description", "callable")
+
+ callable = serializers.SerializerMethodField('_callable')
+
+ def _callable(self, instance):
+ if not 'context_item' in self.context:
+ raise RuntimeError("context isn't getting passed into 2nd level nested serializer")
+ return "success"
+
+ class AlbumSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Album
+ fields = ("photo_set", "callable")
+
+ photo_set = PhotoSerializer(source="photo_set")
+ callable = serializers.SerializerMethodField("_callable")
+
+ def _callable(self, instance):
+ if not 'context_item' in self.context:
+ raise RuntimeError("context isn't getting passed into 1st level nested serializer")
+ return "success"
+
+ class AlbumCollection(object):
+ albums = None
+
+ class AlbumCollectionSerializer(serializers.Serializer):
+ albums = AlbumSerializer(source="albums")
+
+ album1 = Album.objects.create(title="album 1")
+ album2 = Album.objects.create(title="album 2")
+ Photo.objects.create(description="Bigfoot", album=album1)
+ Photo.objects.create(description="Unicorn", album=album1)
+ Photo.objects.create(description="Yeti", album=album2)
+ Photo.objects.create(description="Sasquatch", album=album2)
+ album_collection = AlbumCollection()
+ album_collection.albums = [album1, album2]
+
+ # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
+ AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
+
+
+class DeserializeListTestCase(TestCase):
+
+ def setUp(self):
+ self.data = {
+ 'email': 'nobody@nowhere.com',
+ 'content': 'This is some test content',
+ 'created': datetime.datetime(2013, 3, 7),
+ }
+
+ def test_no_errors(self):
+ data = [self.data.copy() for x in range(0, 3)]
+ serializer = CommentSerializer(data=data, many=True)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, list))
+ self.assertTrue(
+ all((isinstance(item, Comment) for item in serializer.object))
+ )
+
+ def test_errors_return_as_list(self):
+ invalid_item = self.data.copy()
+ invalid_item['email'] = ''
+ data = [self.data.copy(), invalid_item, self.data.copy()]
+
+ serializer = CommentSerializer(data=data, many=True)
+ self.assertFalse(serializer.is_valid())
+ expected = [{}, {'email': ['This field is required.']}, {}]
+ self.assertEqual(serializer.errors, expected)
+
+
+# Test for issue 747
+
+class LazyStringModel(object):
+ def __init__(self, lazystring):
+ self.lazystring = lazystring
+
+
+class LazyStringSerializer(serializers.Serializer):
+ lazystring = serializers.Field()
+
+ def restore_object(self, attrs, instance=None):
+ if instance is not None:
+ instance.lazystring = attrs.get('lazystring', instance.lazystring)
+ return instance
+ return LazyStringModel(**attrs)
+
+
+class LazyStringsTestCase(TestCase):
+ def setUp(self):
+ self.model = LazyStringModel(lazystring=_('lazystring'))
+
+ def test_lazy_strings_are_translated(self):
+ serializer = LazyStringSerializer(self.model)
+ self.assertEqual(type(serializer.data['lazystring']),
+ type('lazystring'))
+
+
+# Test for issue #467
+
+class FieldLabelTest(TestCase):
+ def setUp(self):
+ self.serializer_class = BasicModelSerializer
+
+ def test_label_from_model(self):
+ """
+ Validates that label and help_text are correctly copied from the model class.
+ """
+ serializer = self.serializer_class()
+ text_field = serializer.fields['text']
+
+ self.assertEqual('Text comes here', text_field.label)
+ self.assertEqual('Text description.', text_field.help_text)
+
+ def test_field_ctor(self):
+ """
+ This is check that ctor supports both label and help_text.
+ """
+ self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label)
+ self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text)
+ self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label)
+
+
+# Test for issue #961
+
+class ManyFieldHelpTextTest(TestCase):
+ def test_help_text_no_hold_down_control_msg(self):
+ """
+ Validate that help_text doesn't contain the 'Hold down "Control" ...'
+ message that Django appends to choice fields.
+ """
+ rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text)
+ self.assertEqual('Some help text.', rel_field.help_text)
+
+
+class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
+
+ def setUp(self):
+ class AMOAFModel(RESTFrameworkModel):
+ char_field = models.CharField(max_length=1024, blank=True)
+ comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
+ decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
+ email_field = models.EmailField(max_length=1024, blank=True)
+ file_field = models.FileField(max_length=1024, blank=True)
+ image_field = models.ImageField(max_length=1024, blank=True)
+ slug_field = models.SlugField(max_length=1024, blank=True)
+ url_field = models.URLField(max_length=1024, blank=True)
+
+ class AMOAFSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AMOAFModel
+
+ self.serializer_class = AMOAFSerializer
+ self.fields_attributes = {
+ 'char_field': [
+ ('max_length', 1024),
+ ],
+ 'comma_separated_integer_field': [
+ ('max_length', 1024),
+ ],
+ 'decimal_field': [
+ ('max_digits', 64),
+ ('decimal_places', 32),
+ ],
+ 'email_field': [
+ ('max_length', 1024),
+ ],
+ 'file_field': [
+ ('max_length', 1024),
+ ],
+ 'image_field': [
+ ('max_length', 1024),
+ ],
+ 'slug_field': [
+ ('max_length', 1024),
+ ],
+ 'url_field': [
+ ('max_length', 1024),
+ ],
+ }
+
+ def field_test(self, field):
+ serializer = self.serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+ for attribute in self.fields_attributes[field]:
+ self.assertEqual(
+ getattr(serializer.fields[field], attribute[0]),
+ attribute[1]
+ )
+
+ def test_char_field(self):
+ self.field_test('char_field')
+
+ def test_comma_separated_integer_field(self):
+ self.field_test('comma_separated_integer_field')
+
+ def test_decimal_field(self):
+ self.field_test('decimal_field')
+
+ def test_email_field(self):
+ self.field_test('email_field')
+
+ def test_file_field(self):
+ self.field_test('file_field')
+
+ def test_image_field(self):
+ self.field_test('image_field')
+
+ def test_slug_field(self):
+ self.field_test('slug_field')
+
+ def test_url_field(self):
+ self.field_test('url_field')
+
+
+class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
+
+ def setUp(self):
+ class DVOAFModel(RESTFrameworkModel):
+ positive_integer_field = models.PositiveIntegerField(blank=True)
+ positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
+ email_field = models.EmailField(blank=True)
+ file_field = models.FileField(blank=True)
+ image_field = models.ImageField(blank=True)
+ slug_field = models.SlugField(blank=True)
+ url_field = models.URLField(blank=True)
+
+ class DVOAFSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DVOAFModel
+
+ self.serializer_class = DVOAFSerializer
+ self.fields_attributes = {
+ 'positive_integer_field': [
+ ('min_value', 0),
+ ],
+ 'positive_small_integer_field': [
+ ('min_value', 0),
+ ],
+ 'email_field': [
+ ('max_length', 75),
+ ],
+ 'file_field': [
+ ('max_length', 100),
+ ],
+ 'image_field': [
+ ('max_length', 100),
+ ],
+ 'slug_field': [
+ ('max_length', 50),
+ ],
+ 'url_field': [
+ ('max_length', 200),
+ ],
+ }
+
+ def field_test(self, field):
+ serializer = self.serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+ for attribute in self.fields_attributes[field]:
+ self.assertEqual(
+ getattr(serializer.fields[field], attribute[0]),
+ attribute[1]
+ )
+
+ def test_positive_integer_field(self):
+ self.field_test('positive_integer_field')
+
+ def test_positive_small_integer_field(self):
+ self.field_test('positive_small_integer_field')
+
+ def test_email_field(self):
+ self.field_test('email_field')
+
+ def test_file_field(self):
+ self.field_test('file_field')
+
+ def test_image_field(self):
+ self.field_test('image_field')
+
+ def test_slug_field(self):
+ self.field_test('slug_field')
+
+ def test_url_field(self):
+ self.field_test('url_field')
+
+
+class MetadataSerializer(serializers.Serializer):
+ field1 = serializers.CharField(3, required=True)
+ field2 = serializers.CharField(10, required=False)
+
+
+class MetadataSerializerTestCase(TestCase):
+ def setUp(self):
+ self.serializer = MetadataSerializer()
+
+ def test_serializer_metadata(self):
+ metadata = self.serializer.metadata()
+ expected = {
+ 'field1': {
+ 'required': True,
+ 'max_length': 3,
+ 'type': 'string',
+ 'read_only': False
+ },
+ 'field2': {
+ 'required': False,
+ 'max_length': 10,
+ 'type': 'string',
+ 'read_only': False
+ }
+ }
+ self.assertEqual(expected, metadata)
+
+
+### Regression test for #840
+
+class SimpleModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimpleModelSerializer(serializers.ModelSerializer):
+ text = serializers.CharField()
+ other = serializers.CharField()
+
+ class Meta:
+ model = SimpleModel
+
+ def validate_other(self, attrs, source):
+ del attrs['other']
+ return attrs
+
+
+class FieldValidationRemovingAttr(TestCase):
+ def test_removing_non_model_field_in_validation(self):
+ """
+ Removing an attr during field valiation should ensure that it is not
+ passed through when restoring the object.
+
+ This allows additional non-model fields to be supported.
+
+ Regression test for #840.
+ """
+ serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.object.text, 'foo')
+
+
+### Regression test for #878
+
+class SimpleTargetModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimplePKSourceModelSerializer(serializers.Serializer):
+ targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True)
+ text = serializers.CharField()
+
+
+class SimpleSlugSourceModelSerializer(serializers.Serializer):
+ targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk')
+ text = serializers.CharField()
+
+
+class SerializerSupportsManyRelationships(TestCase):
+ def setUp(self):
+ SimpleTargetModel.objects.create(text='foo')
+ SimpleTargetModel.objects.create(text='bar')
+
+ def test_serializer_supports_pk_many_relationships(self):
+ """
+ Regression test for #878.
+
+ Note that pk behavior has a different code path to usual cases,
+ for performance reasons.
+ """
+ serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+ def test_serializer_supports_slug_many_relationships(self):
+ """
+ Regression test for #878.
+ """
+ serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+
+class TransformMethodsSerializer(serializers.Serializer):
+ a = serializers.CharField()
+ b_renamed = serializers.CharField(source='b')
+
+ def transform_a(self, obj, value):
+ return value.lower()
+
+ def transform_b_renamed(self, obj, value):
+ if value is not None:
+ return 'and ' + value
+
+
+class TestSerializerTransformMethods(TestCase):
+ def setUp(self):
+ self.s = TransformMethodsSerializer()
+
+ def test_transform_methods(self):
+ self.assertEqual(
+ self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}),
+ {
+ 'a': 'green eggs',
+ 'b_renamed': 'and HAM',
+ }
+ )
+
+ def test_missing_fields(self):
+ self.assertEqual(
+ self.s.to_native({'a': 'GREEN EGGS'}),
+ {
+ 'a': 'green eggs',
+ 'b_renamed': None,
+ }
+ )
+
+
+class DefaultTrueBooleanModel(models.Model):
+ cat = models.BooleanField(default=True)
+ dog = models.BooleanField(default=False)
+
+
+class SerializerDefaultTrueBoolean(TestCase):
+
+ def setUp(self):
+ super(SerializerDefaultTrueBoolean, self).setUp()
+
+ class DefaultTrueBooleanSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DefaultTrueBooleanModel
+ fields = ('cat', 'dog')
+
+ self.default_true_boolean_serializer = DefaultTrueBooleanSerializer
+
+ def test_enabled_as_false(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': False,
+ 'dog': False})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], False)
+ self.assertEqual(serializer.data['dog'], False)
+
+ def test_enabled_as_true(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': True,
+ 'dog': True})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], True)
+ self.assertEqual(serializer.data['dog'], True)
+
+ def test_enabled_partial(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': False},
+ partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], False)
+ self.assertEqual(serializer.data['dog'], False)
+
+
+class BoolenFieldTypeTest(TestCase):
+ '''
+ Ensure the various Boolean based model fields are rendered as the proper
+ field type
+
+ '''
+
+ def setUp(self):
+ '''
+ Setup an ActionItemSerializer for BooleanTesting
+ '''
+ data = {
+ 'title': 'b' * 201,
+ }
+ self.serializer = ActionItemSerializer(data=data)
+
+ def test_booleanfield_type(self):
+ '''
+ Test that BooleanField is infered from models.BooleanField
+ '''
+ bfield = self.serializer.get_fields()['done']
+ self.assertEqual(type(bfield), fields.BooleanField)
+
+ def test_nullbooleanfield_type(self):
+ '''
+ Test that BooleanField is infered from models.NullBooleanField
+
+ https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8
+ '''
+ bfield = self.serializer.get_fields()['started']
+ self.assertEqual(type(bfield), fields.BooleanField)
diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py
new file mode 100644
index 00000000..8b0ded1a
--- /dev/null
+++ b/tests/test_serializer_bulk_update.py
@@ -0,0 +1,278 @@
+"""
+Tests to cover bulk create and update using serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class BulkCreateSerializerTests(TestCase):
+ """
+ Creating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ self.BookSerializer = BookSerializer
+
+ def test_bulk_create_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_bulk_create_errors(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 'foo',
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_list_datatype(self):
+ """
+ Data containing list of incorrect data type should return errors.
+ """
+ data = ['foo', 'bar', 'baz']
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = [
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']}
+ ]
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_datatype(self):
+ """
+ Data containing a single incorrect data type should return errors.
+ """
+ data = 123
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_object(self):
+ """
+ Data containing only a single object, instead of a list of objects
+ should return errors.
+ """
+ data = {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class BulkUpdateSerializerTests(TestCase):
+ """
+ Updating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class Book(object):
+ """
+ A data type that can be persisted to a mock storage backend
+ with `.save()` and `.delete()`.
+ """
+ object_map = {}
+
+ def __init__(self, id, title, author):
+ self.id = id
+ self.title = title
+ self.author = author
+
+ def save(self):
+ Book.object_map[self.id] = self
+
+ def delete(self):
+ del Book.object_map[self.id]
+
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.id = attrs['id']
+ instance.title = attrs['title']
+ instance.author = attrs['author']
+ return instance
+ return Book(**attrs)
+
+ self.Book = Book
+ self.BookSerializer = BookSerializer
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ for item in data:
+ book = Book(item['id'], item['title'], item['author'])
+ book.save()
+
+ def books(self):
+ """
+ Return all the objects in the mock storage backend.
+ """
+ return self.Book.object_map.values()
+
+ def test_bulk_update_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 2,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_and_create(self):
+ """
+ Bulk update serialization may also include created items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_invalid_create(self):
+ """
+ Bulk update serialization without allow_add_remove may not create items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_bulk_update_error(self):
+ """
+ Incorrect bulk update serialization should return error data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 'foo',
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
diff --git a/tests/test_serializer_empty.py b/tests/test_serializer_empty.py
new file mode 100644
index 00000000..30cff361
--- /dev/null
+++ b/tests/test_serializer_empty.py
@@ -0,0 +1,15 @@
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class EmptySerializerTestCase(TestCase):
+ def test_empty_serializer(self):
+ class FooBarSerializer(serializers.Serializer):
+ foo = serializers.IntegerField()
+ bar = serializers.SerializerMethodField('get_bar')
+
+ def get_bar(self, obj):
+ return 'bar'
+
+ serializer = FooBarSerializer()
+ self.assertEquals(serializer.data, {'foo': 0})
diff --git a/tests/test_serializer_import.py b/tests/test_serializer_import.py
new file mode 100644
index 00000000..3b8ff4b3
--- /dev/null
+++ b/tests/test_serializer_import.py
@@ -0,0 +1,19 @@
+from django.test import TestCase
+
+from rest_framework import serializers
+from tests.accounts.serializers import AccountSerializer
+
+
+class ImportingModelSerializerTests(TestCase):
+ """
+ In some situations like, GH #1225, it is possible, especially in
+ testing, to import a serializer who's related models have not yet
+ been resolved by Django. `AccountSerializer` is an example of such
+ a serializer (imported at the top of this file).
+ """
+ def test_import_model_serializer(self):
+ """
+ The serializer at the top of this file should have been
+ imported successfully, and we should be able to instantiate it.
+ """
+ self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer)
diff --git a/tests/test_serializer_nested.py b/tests/test_serializer_nested.py
new file mode 100644
index 00000000..6d69ffbd
--- /dev/null
+++ b/tests/test_serializer_nested.py
@@ -0,0 +1,347 @@
+"""
+Tests to cover nested serializers.
+
+Doesn't cover model serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+from . import models
+
+
+class WritableNestedSerializerBasicTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ Basic tests that use serializers that simply restore to dicts.
+ """
+
+ def setUp(self):
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ expected_errors = {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_many_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data
+ when multiple entities are being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ ]
+ expected_errors = [
+ {},
+ {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+ ]
+
+ serializer = self.AlbumSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class WritableNestedSerializerObjectTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ These tests use serializers that restore to concrete objects.
+ """
+
+ def setUp(self):
+ # Couple of concrete objects that we're going to deserialize into
+ class Track(object):
+ def __init__(self, order, title, duration):
+ self.order, self.title, self.duration = order, title, duration
+
+ def __eq__(self, other):
+ return (
+ self.order == other.order and
+ self.title == other.title and
+ self.duration == other.duration
+ )
+
+ class Album(object):
+ def __init__(self, album_name, artist, tracks):
+ self.album_name, self.artist, self.tracks = album_name, artist, tracks
+
+ def __eq__(self, other):
+ return (
+ self.album_name == other.album_name and
+ self.artist == other.artist and
+ self.tracks == other.tracks
+ )
+
+ # And their corresponding serializers
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ def restore_object(self, attrs, instance=None):
+ return Track(attrs['order'], attrs['title'], attrs['duration'])
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ def restore_object(self, attrs, instance=None):
+ return Album(attrs['album_name'], attrs['artist'], attrs['tracks'])
+
+ self.Album, self.Track = Album, Track
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return a restored object
+ that corresponds to the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ expected_object = self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
+
+ def test_many_nested_validation_success(self):
+ """
+ Correct nested serialization should return multiple restored objects
+ that corresponds to the input data when multiple objects are
+ being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ ]
+ expected_object = [
+ self.Album(
+ album_name='Russian Red',
+ artist='I Love Your Glasses',
+ tracks=[
+ self.Track(order=1, title='Cigarettes', duration=121),
+ self.Track(order=2, title='No Past Land', duration=198),
+ self.Track(order=3, title='They Don\'t Believe', duration=191),
+ ]
+ ),
+ self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+ ]
+
+ serializer = self.AlbumSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
+
+
+class ForeignKeyNestedSerializerUpdateTests(TestCase):
+ def setUp(self):
+ class Artist(object):
+ def __init__(self, name):
+ self.name = name
+
+ def __eq__(self, other):
+ return self.name == other.name
+
+ class Album(object):
+ def __init__(self, name, artist):
+ self.name, self.artist = name, artist
+
+ def __eq__(self, other):
+ return self.name == other.name and self.artist == other.artist
+
+ class ArtistSerializer(serializers.Serializer):
+ name = serializers.CharField()
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.name = attrs['name']
+ else:
+ instance = Artist(attrs['name'])
+ return instance
+
+ class AlbumSerializer(serializers.Serializer):
+ name = serializers.CharField()
+ by = ArtistSerializer(source='artist')
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.name = attrs['name']
+ instance.artist = attrs['artist']
+ else:
+ instance = Album(attrs['name'], attrs['artist'])
+ return instance
+
+ self.Artist = Artist
+ self.Album = Album
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_create_via_foreign_key_with_source(self):
+ """
+ Check that we can both *create* and *update* into objects across
+ ForeignKeys that have a `source` specified.
+ Regression test for #1170
+ """
+ data = {
+ 'name': 'Discovery',
+ 'by': {'name': 'Daft Punk'},
+ }
+
+ expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery')
+
+ # create
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+
+ # update
+ original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters')
+ serializer = self.AlbumSerializer(instance=original, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+
+
+class NestedModelSerializerUpdateTests(TestCase):
+ def test_second_nested_level(self):
+ john = models.Person.objects.create(name="john")
+
+ post = john.blogpost_set.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = models.BlogPostComment
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
+ class Meta:
+ model = models.BlogPost
+ fields = ('id', 'title', 'comments')
+
+ class PersonSerializer(serializers.ModelSerializer):
+ posts = BlogPostSerializer(many=True, source='blogpost_set')
+ class Meta:
+ model = models.Person
+ fields = ('id', 'name', 'age', 'posts')
+
+ serialize = PersonSerializer(instance=john)
+ deserialize = PersonSerializer(data=serialize.data, instance=john)
+ self.assertTrue(deserialize.is_valid())
+
+ result = deserialize.object
+ result.save()
+ self.assertEqual(result.id, john.id)
diff --git a/tests/test_serializers.py b/tests/test_serializers.py
new file mode 100644
index 00000000..67547783
--- /dev/null
+++ b/tests/test_serializers.py
@@ -0,0 +1,28 @@
+from django.db import models
+from django.test import TestCase
+
+from rest_framework.serializers import _resolve_model
+from tests.models import BasicModel
+
+
+class ResolveModelTests(TestCase):
+ """
+ `_resolve_model` should return a Django model class given the
+ provided argument is a Django model class itself, or a properly
+ formatted string representation of one.
+ """
+ def test_resolve_django_model(self):
+ resolved_model = _resolve_model(BasicModel)
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_string_representation(self):
+ resolved_model = _resolve_model('tests.BasicModel')
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_non_django_model(self):
+ with self.assertRaises(ValueError):
+ _resolve_model(TestCase)
+
+ def test_resolve_improper_string_representation(self):
+ with self.assertRaises(ValueError):
+ _resolve_model('BasicModel')
diff --git a/tests/test_settings.py b/tests/test_settings.py
new file mode 100644
index 00000000..e29fc34a
--- /dev/null
+++ b/tests/test_settings.py
@@ -0,0 +1,22 @@
+"""Tests for the settings module"""
+from __future__ import unicode_literals
+from django.test import TestCase
+
+from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
+
+
+class TestSettings(TestCase):
+ """Tests relating to the api settings"""
+
+ def test_non_import_errors(self):
+ """Make sure other errors aren't suppressed."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ValueError):
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+
+ def test_import_error_message_maintained(self):
+ """Make sure real import errors are captured and raised sensibly."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ImportError) as cm:
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+ self.assertTrue('ImportError' in str(cm.exception))
diff --git a/tests/test_status.py b/tests/test_status.py
new file mode 100644
index 00000000..7b1bdae3
--- /dev/null
+++ b/tests/test_status.py
@@ -0,0 +1,33 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.status import (
+ is_informational, is_success, is_redirect, is_client_error, is_server_error
+)
+
+
+class TestStatus(TestCase):
+ def test_status_categories(self):
+ self.assertFalse(is_informational(99))
+ self.assertTrue(is_informational(100))
+ self.assertTrue(is_informational(199))
+ self.assertFalse(is_informational(200))
+
+ self.assertFalse(is_success(199))
+ self.assertTrue(is_success(200))
+ self.assertTrue(is_success(299))
+ self.assertFalse(is_success(300))
+
+ self.assertFalse(is_redirect(299))
+ self.assertTrue(is_redirect(300))
+ self.assertTrue(is_redirect(399))
+ self.assertFalse(is_redirect(400))
+
+ self.assertFalse(is_client_error(399))
+ self.assertTrue(is_client_error(400))
+ self.assertTrue(is_client_error(499))
+ self.assertFalse(is_client_error(500))
+
+ self.assertFalse(is_server_error(499))
+ self.assertTrue(is_server_error(500))
+ self.assertTrue(is_server_error(599))
+ self.assertFalse(is_server_error(600)) \ No newline at end of file
diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py
new file mode 100644
index 00000000..d4da0c23
--- /dev/null
+++ b/tests/test_templatetags.py
@@ -0,0 +1,51 @@
+# encoding: utf-8
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links
+
+factory = APIRequestFactory()
+
+
+class TemplateTagTests(TestCase):
+
+ def test_add_query_param_with_non_latin_charactor(self):
+ # Ensure we don't double-escape non-latin characters
+ # that are present in the querystring.
+ # See #1314.
+ request = factory.get("/", {'q': '查询'})
+ json_url = add_query_param(request, "format", "json")
+ self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url)
+ self.assertIn("format=json", json_url)
+
+
+class Issue1386Tests(TestCase):
+ """
+ Covers #1386
+ """
+
+ def test_issue_1386(self):
+ """
+ Test function urlize_quoted_links with different args
+ """
+ correct_urls = [
+ "asdf.com",
+ "asdf.net",
+ "www.as_df.org",
+ "as.d8f.ghj8.gov",
+ ]
+ for i in correct_urls:
+ res = urlize_quoted_links(i)
+ self.assertNotEqual(res, i)
+ self.assertIn(i, res)
+
+ incorrect_urls = [
+ "mailto://asdf@fdf.com",
+ "asdf.netnet",
+ ]
+ for i in incorrect_urls:
+ res = urlize_quoted_links(i)
+ self.assertEqual(i, res)
+
+ # example from issue #1386, this shouldn't raise an exception
+ _ = urlize_quoted_links("asdf:[/p]zxcv.com")
diff --git a/tests/test_testing.py b/tests/test_testing.py
new file mode 100644
index 00000000..8c6086a2
--- /dev/null
+++ b/tests/test_testing.py
@@ -0,0 +1,154 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
+from io import BytesIO
+
+from django.contrib.auth.models import User
+from django.test import TestCase
+from rest_framework.compat import patterns, url
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
+
+
+@api_view(['GET', 'POST'])
+def view(request):
+ return Response({
+ 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
+ 'user': request.user.username
+ })
+
+
+@api_view(['GET', 'POST'])
+def session_view(request):
+ active_session = request.session.get('active_session', False)
+ request.session['active_session'] = True
+ return Response({
+ 'active_session': active_session
+ })
+
+
+urlpatterns = patterns('',
+ url(r'^view/$', view),
+ url(r'^session-view/$', session_view),
+)
+
+
+class TestAPITestClient(TestCase):
+ urls = 'tests.test_testing'
+
+ def setUp(self):
+ self.client = APIClient()
+
+ def test_credentials(self):
+ """
+ Setting `.credentials()` adds the required headers to each request.
+ """
+ self.client.credentials(HTTP_AUTHORIZATION='example')
+ for _ in range(0, 3):
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['auth'], 'example')
+
+ def test_force_authenticate(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['user'], 'example')
+
+ def test_force_authenticate_with_sessions(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+
+ # First request does not yet have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ # Subsequant requests have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], True)
+
+ # Force authenticating as `None` should also logout the user session.
+ self.client.force_authenticate(None)
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ User.objects.create_user('example', 'example@example.com', 'password')
+ self.client.login(username='example', password='password')
+ response = self.client.post('/view/')
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ client = APIClient(enforce_csrf_checks=True)
+ User.objects.create_user('example', 'example@example.com', 'password')
+ client.login(username='example', password='password')
+ response = client.post('/view/')
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+
+class TestAPIRequestFactory(TestCase):
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory()
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory(enforce_csrf_checks=True)
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+ def test_invalid_format(self):
+ """
+ Attempting to use a format that is not configured will raise an
+ assertion error.
+ """
+ factory = APIRequestFactory()
+ self.assertRaises(AssertionError, factory.post,
+ path='/view/', data={'example': 1}, format='xml'
+ )
+
+ def test_force_authenticate(self):
+ """
+ Setting `force_authenticate()` forcibly authenticates the request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ factory = APIRequestFactory()
+ request = factory.get('/view')
+ force_authenticate(request, user=user)
+ response = view(request)
+ self.assertEqual(response.data['user'], 'example')
+
+ def test_upload_file(self):
+ # This is a 1x1 black png
+ simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82')
+ simple_png.name = 'test.png'
+ factory = APIRequestFactory()
+ factory.post('/', data={'image': simple_png})
diff --git a/tests/test_throttling.py b/tests/test_throttling.py
new file mode 100644
index 00000000..41bff692
--- /dev/null
+++ b/tests/test_throttling.py
@@ -0,0 +1,277 @@
+"""
+Tests for the throttling implementations in the permissions module.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.contrib.auth.models import User
+from django.core.cache import cache
+from rest_framework.test import APIRequestFactory
+from rest_framework.views import APIView
+from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
+from rest_framework.response import Response
+
+
+class User3SecRateThrottle(UserRateThrottle):
+ rate = '3/sec'
+ scope = 'seconds'
+
+
+class User3MinRateThrottle(UserRateThrottle):
+ rate = '3/min'
+ scope = 'minutes'
+
+
+class NonTimeThrottle(BaseThrottle):
+ def allow_request(self, request, view):
+ if not hasattr(self.__class__, 'called'):
+ self.__class__.called = True
+ return True
+ return False
+
+
+class MockView(APIView):
+ throttle_classes = (User3SecRateThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class MockView_MinuteThrottling(APIView):
+ throttle_classes = (User3MinRateThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class MockView_NonTimeThrottling(APIView):
+ throttle_classes = (NonTimeThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class ThrottlingTests(TestCase):
+ def setUp(self):
+ """
+ Reset the cache so that no throttles will be active
+ """
+ cache.clear()
+ self.factory = APIRequestFactory()
+
+ def test_requests_are_throttled(self):
+ """
+ Ensure request rate is limited
+ """
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(429, response.status_code)
+
+ def set_throttle_timer(self, view, value):
+ """
+ Explicitly set the timer, overriding time.time()
+ """
+ view.throttle_classes[0].timer = lambda self: value
+
+ def test_request_throttling_expires(self):
+ """
+ Ensure request rate is limited for a limited duration only
+ """
+ self.set_throttle_timer(MockView, 0)
+
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(429, response.status_code)
+
+ # Advance the timer by one second
+ self.set_throttle_timer(MockView, 1)
+
+ response = MockView.as_view()(request)
+ self.assertEqual(200, response.status_code)
+
+ def ensure_is_throttled(self, view, expect):
+ request = self.factory.get('/')
+ request.user = User.objects.create(username='a')
+ for dummy in range(3):
+ view.as_view()(request)
+ request.user = User.objects.create(username='b')
+ response = view.as_view()(request)
+ self.assertEqual(expect, response.status_code)
+
+ def test_request_throttling_is_per_user(self):
+ """
+ Ensure request rate is only limited per user, not globally for
+ PerUserThrottles
+ """
+ self.ensure_is_throttled(MockView, 200)
+
+ def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
+ """
+ Ensure the response returns an X-Throttle field with status and next attributes
+ set properly.
+ """
+ request = self.factory.get('/')
+ for timer, expect in expected_headers:
+ self.set_throttle_timer(view, timer)
+ response = view.as_view()(request)
+ if expect is not None:
+ self.assertEqual(response['X-Throttle-Wait-Seconds'], expect)
+ else:
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+ def test_seconds_fields(self):
+ """
+ Ensure for second based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView,
+ ((0, None),
+ (0, None),
+ (0, None),
+ (0, '1')
+ ))
+
+ def test_minutes_fields(self):
+ """
+ Ensure for minute based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, None),
+ (0, None),
+ (0, None),
+ (0, '60')
+ ))
+
+ def test_next_rate_remains_constant_if_followed(self):
+ """
+ If a client follows the recommended next request rate,
+ the throttling rate should stay constant.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, None),
+ (20, None),
+ (40, None),
+ (60, None),
+ (80, None)
+ ))
+
+ def test_non_time_throttle(self):
+ """
+ Ensure for second based throttles.
+ """
+ request = self.factory.get('/')
+
+ self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+ self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+
+class ScopedRateThrottleTests(TestCase):
+ """
+ Tests for ScopedRateThrottle.
+ """
+
+ def setUp(self):
+ class XYScopedRateThrottle(ScopedRateThrottle):
+ TIMER_SECONDS = 0
+ THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
+ timer = lambda self: self.TIMER_SECONDS
+
+ class XView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'x'
+
+ def get(self, request):
+ return Response('x')
+
+ class YView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'y'
+
+ def get(self, request):
+ return Response('y')
+
+ class UnscopedView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+
+ def get(self, request):
+ return Response('y')
+
+ self.throttle_class = XYScopedRateThrottle
+ self.factory = APIRequestFactory()
+ self.x_view = XView.as_view()
+ self.y_view = YView.as_view()
+ self.unscoped_view = UnscopedView.as_view()
+
+ def increment_timer(self, seconds=1):
+ self.throttle_class.TIMER_SECONDS += seconds
+
+ def test_scoped_rate_throttle(self):
+ request = self.factory.get('/')
+
+ # Should be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Ensure throttles properly reset by advancing the rest of the minute
+ self.increment_timer(55)
+
+ # Should still be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should still be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ def test_unscoped_view_not_throttled(self):
+ request = self.factory.get('/')
+
+ for idx in range(10):
+ self.increment_timer()
+ response = self.unscoped_view(request)
+ self.assertEqual(200, response.status_code)
diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py
new file mode 100644
index 00000000..8132ec4c
--- /dev/null
+++ b/tests/test_urlpatterns.py
@@ -0,0 +1,76 @@
+from __future__ import unicode_literals
+from collections import namedtuple
+from django.core import urlresolvers
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+from rest_framework.compat import patterns, url, include
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+# A container class for test paths for the test case
+URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+
+
+def dummy_view(request, *args, **kwargs):
+ pass
+
+
+class FormatSuffixTests(TestCase):
+ """
+ Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
+ """
+ def _resolve_urlpatterns(self, urlpatterns, test_paths):
+ factory = APIRequestFactory()
+ try:
+ urlpatterns = format_suffix_patterns(urlpatterns)
+ except Exception:
+ self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
+ resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
+ for test_path in test_paths:
+ request = factory.get(test_path.path)
+ try:
+ callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ except Exception:
+ self.fail("Failed to resolve URL: %s" % request.path_info)
+ self.assertEqual(callback_args, test_path.args)
+ self.assertEqual(callback_kwargs, test_path.kwargs)
+
+ def test_format_suffix(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {}),
+ URLTestPath('/test.api', (), {'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_default_args(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view, {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {'foo': 'bar', }),
+ URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_included_urls(self):
+ nested_patterns = patterns(
+ '',
+ url(r'^path$', dummy_view)
+ )
+ urlpatterns = patterns(
+ '',
+ url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test/path', (), {'foo': 'bar', }),
+ URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/tests/test_validation.py b/tests/test_validation.py
new file mode 100644
index 00000000..124c874d
--- /dev/null
+++ b/tests/test_validation.py
@@ -0,0 +1,104 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import generics, serializers, status
+from rest_framework.test import APIRequestFactory
+
+factory = APIRequestFactory()
+
+
+# Regression for #666
+
+class ValidationModel(models.Model):
+ blank_validated_field = models.CharField(max_length=255)
+
+
+class ValidationModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ValidationModel
+ fields = ('blank_validated_field',)
+ read_only_fields = ('blank_validated_field',)
+
+
+class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
+ model = ValidationModel
+ serializer_class = ValidationModelSerializer
+
+
+class TestPreSaveValidationExclusions(TestCase):
+ def test_pre_save_validation_exclusions(self):
+ """
+ Somewhat weird test case to ensure that we don't perform model
+ validation on read only fields.
+ """
+ obj = ValidationModel.objects.create(blank_validated_field='')
+ request = factory.put('/', {}, format='json')
+ view = UpdateValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+# Regression for #653
+
+class ShouldValidateModel(models.Model):
+ should_validate_field = models.CharField(max_length=255)
+
+
+class ShouldValidateModelSerializer(serializers.ModelSerializer):
+ renamed = serializers.CharField(source='should_validate_field', required=False)
+
+ def validate_renamed(self, attrs, source):
+ value = attrs[source]
+ if len(value) < 3:
+ raise serializers.ValidationError('Minimum 3 characters.')
+ return attrs
+
+ class Meta:
+ model = ShouldValidateModel
+ fields = ('renamed',)
+
+
+class TestPreSaveValidationExclusionsSerializer(TestCase):
+ def test_renamed_fields_are_model_validated(self):
+ """
+ Ensure fields with 'source' applied do get still get model validation.
+ """
+ # We've set `required=False` on the serializer, but the model
+ # does not have `blank=True`, so this serializer should not validate.
+ serializer = ShouldValidateModelSerializer(data={'renamed': ''})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertIn('renamed', serializer.errors)
+ self.assertNotIn('should_validate_field', serializer.errors)
+
+
+class TestCustomValidationMethods(TestCase):
+ def test_custom_validation_method_is_executed(self):
+ serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'})
+ self.assertFalse(serializer.is_valid())
+ self.assertIn('renamed', serializer.errors)
+
+ def test_custom_validation_method_passing(self):
+ serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'})
+ self.assertTrue(serializer.is_valid())
+
+
+class ValidationSerializer(serializers.Serializer):
+ foo = serializers.CharField()
+
+ def validate_foo(self, attrs, source):
+ raise serializers.ValidationError("foo invalid")
+
+ def validate(self, attrs):
+ raise serializers.ValidationError("serializer invalid")
+
+
+class TestAvoidValidation(TestCase):
+ """
+ If serializer was initialized with invalid data (None or non dict-like), it
+ should avoid validation layer (validate_<field> and validate methods)
+ """
+ def test_serializer_errors_has_only_invalid_data_error(self):
+ serializer = ValidationSerializer(data='invalid data')
+ self.assertFalse(serializer.is_valid())
+ self.assertDictEqual(serializer.errors,
+ {'non_field_errors': ['Invalid data']})
diff --git a/tests/test_views.py b/tests/test_views.py
new file mode 100644
index 00000000..65c7e50e
--- /dev/null
+++ b/tests/test_views.py
@@ -0,0 +1,142 @@
+from __future__ import unicode_literals
+
+import copy
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from rest_framework.views import APIView
+
+factory = APIRequestFactory()
+
+
+class BasicView(APIView):
+ def get(self, request, *args, **kwargs):
+ return Response({'method': 'GET'})
+
+ def post(self, request, *args, **kwargs):
+ return Response({'method': 'POST', 'data': request.DATA})
+
+
+@api_view(['GET', 'POST', 'PUT', 'PATCH'])
+def basic_view(request):
+ if request.method == 'GET':
+ return {'method': 'GET'}
+ elif request.method == 'POST':
+ return {'method': 'POST', 'data': request.DATA}
+ elif request.method == 'PUT':
+ return {'method': 'PUT', 'data': request.DATA}
+ elif request.method == 'PATCH':
+ return {'method': 'PATCH', 'data': request.DATA}
+
+
+class ErrorView(APIView):
+ def get(self, request, *args, **kwargs):
+ raise Exception
+
+
+@api_view(['GET'])
+def error_view(request):
+ raise Exception
+
+
+def sanitise_json_error(error_dict):
+ """
+ Exact contents of JSON error messages depend on the installed version
+ of json.
+ """
+ ret = copy.copy(error_dict)
+ chop = len('JSON parse error - No JSON object could be decoded')
+ ret['detail'] = ret['detail'][:chop]
+ return ret
+
+
+class ClassBasedViewIntegrationTests(TestCase):
+ def setUp(self):
+ self.view = BasicView.as_view()
+
+ def test_400_parse_error(self):
+ request = factory.post('/', 'f00bar', content_type='application/json')
+ response = self.view(request)
+ expected = {
+ 'detail': 'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+ def test_400_parse_error_tunneled_content(self):
+ content = 'f00bar'
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = factory.post('/', form_data)
+ response = self.view(request)
+ expected = {
+ 'detail': 'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+
+class FunctionBasedViewIntegrationTests(TestCase):
+ def setUp(self):
+ self.view = basic_view
+
+ def test_400_parse_error(self):
+ request = factory.post('/', 'f00bar', content_type='application/json')
+ response = self.view(request)
+ expected = {
+ 'detail': 'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+ def test_400_parse_error_tunneled_content(self):
+ content = 'f00bar'
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = factory.post('/', form_data)
+ response = self.view(request)
+ expected = {
+ 'detail': 'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+
+class TestCustomExceptionHandler(TestCase):
+ def setUp(self):
+ self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
+
+ def exception_handler(exc):
+ return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
+
+ api_settings.EXCEPTION_HANDLER = exception_handler
+
+ def tearDown(self):
+ api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
+
+ def test_class_based_view_exception_handler(self):
+ view = ErrorView.as_view()
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
+
+ def test_function_based_view_exception_handler(self):
+ view = error_view
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
diff --git a/tests/test_write_only_fields.py b/tests/test_write_only_fields.py
new file mode 100644
index 00000000..aabb18d6
--- /dev/null
+++ b/tests/test_write_only_fields.py
@@ -0,0 +1,42 @@
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class ExampleModel(models.Model):
+ email = models.EmailField(max_length=100)
+ password = models.CharField(max_length=100)
+
+
+class WriteOnlyFieldTests(TestCase):
+ def test_write_only_fields(self):
+ class ExampleSerializer(serializers.Serializer):
+ email = serializers.EmailField()
+ password = serializers.CharField(write_only=True)
+
+ data = {
+ 'email': 'foo@example.com',
+ 'password': '123'
+ }
+ serializer = ExampleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.object, data)
+ self.assertEquals(serializer.data, {'email': 'foo@example.com'})
+
+ def test_write_only_fields_meta(self):
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ExampleModel
+ fields = ('email', 'password')
+ write_only_fields = ('password',)
+
+ data = {
+ 'email': 'foo@example.com',
+ 'password': '123'
+ }
+ serializer = ExampleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, ExampleModel))
+ self.assertEquals(serializer.object.email, data['email'])
+ self.assertEquals(serializer.object.password, data['password'])
+ self.assertEquals(serializer.data, {'email': 'foo@example.com'})
diff --git a/tests/urls.py b/tests/urls.py
new file mode 100644
index 00000000..62cad339
--- /dev/null
+++ b/tests/urls.py
@@ -0,0 +1,6 @@
+"""
+Blank URLConf just to keep the test suite happy
+"""
+from rest_framework.compat import patterns
+
+urlpatterns = patterns('')
diff --git a/tests/users/__init__.py b/tests/users/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/users/__init__.py
diff --git a/tests/users/models.py b/tests/users/models.py
new file mode 100644
index 00000000..128bac90
--- /dev/null
+++ b/tests/users/models.py
@@ -0,0 +1,6 @@
+from django.db import models
+
+
+class User(models.Model):
+ account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users')
+ active_record = models.ForeignKey('records.Record', blank=True, null=True)
diff --git a/tests/users/serializers.py b/tests/users/serializers.py
new file mode 100644
index 00000000..4893ddb3
--- /dev/null
+++ b/tests/users/serializers.py
@@ -0,0 +1,8 @@
+from rest_framework import serializers
+
+from tests.users.models import User
+
+
+class UserSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = User
diff --git a/tests/views.py b/tests/views.py
new file mode 100644
index 00000000..55935e92
--- /dev/null
+++ b/tests/views.py
@@ -0,0 +1,8 @@
+from rest_framework import generics
+from .models import NullableForeignKeySource
+from .serializers import NullableFKSourceSerializer
+
+
+class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = NullableForeignKeySource
+ model_serializer_class = NullableFKSourceSerializer