aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
authorCraig de Stigter2013-05-18 10:23:26 +0200
committerCraig de Stigter2013-05-18 10:23:26 +0200
commitc56d5f8f63673e2d0235015e67d78f8e0e1d7550 (patch)
treebf96abef843accae2439413de45095de2995f473 /rest_framework/tests
parent84be169353f0dd2ceb06fe459b72aa2452fcbeb5 (diff)
parent34776da9249a5d73f822b3562bc56a5674b10ac7 (diff)
downloaddjango-rest-framework-c56d5f8f63673e2d0235015e67d78f8e0e1d7550.tar.bz2
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/authentication.py341
-rw-r--r--rest_framework/tests/description.py63
-rw-r--r--rest_framework/tests/fields.py583
-rw-r--r--rest_framework/tests/filters.py474
-rw-r--r--rest_framework/tests/filterset.py169
-rw-r--r--rest_framework/tests/generics.py164
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py40
-rw-r--r--rest_framework/tests/models.py7
-rw-r--r--rest_framework/tests/multitable_inheritance.py67
-rw-r--r--rest_framework/tests/pagination.py112
-rw-r--r--rest_framework/tests/parsers.py33
-rw-r--r--rest_framework/tests/relations_hyperlink.py16
-rw-r--r--rest_framework/tests/relations_nested.py24
-rw-r--r--rest_framework/tests/relations_pk.py23
-rw-r--r--rest_framework/tests/request.py8
-rw-r--r--rest_framework/tests/routers.py55
-rw-r--r--rest_framework/tests/serializer.py135
-rw-r--r--rest_framework/tests/serializer_bulk_update.py278
-rw-r--r--rest_framework/tests/serializer_nested.py246
-rw-r--r--rest_framework/tests/status.py13
20 files changed, 2491 insertions, 360 deletions
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 7b754af5..8e6d3e51 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -2,23 +2,29 @@ from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import Client, TestCase
+from django.utils import unittest
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework import permissions
from rest_framework import status
-from rest_framework.authtoken.models import Token
from rest_framework.authentication import (
BaseAuthentication,
TokenAuthentication,
BasicAuthentication,
- SessionAuthentication
+ SessionAuthentication,
+ OAuthAuthentication,
+ OAuth2Authentication
)
-from rest_framework.compat import patterns
+from rest_framework.authtoken.models import Token
+from rest_framework.compat import patterns, url, include
+from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
+from rest_framework.compat import oauth, oauth_provider
from rest_framework.tests.utils import RequestFactory
from rest_framework.views import APIView
import json
import base64
-
+import time
+import datetime
factory = RequestFactory()
@@ -41,8 +47,19 @@ urlpatterns = patterns('',
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ permission_classes=[permissions.TokenHasReadWriteScope]))
)
+if oauth2_provider is not None:
+ urlpatterns += patterns('',
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
+ url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ permission_classes=[permissions.TokenHasReadWriteScope])),
+ )
+
class BasicAuthTests(TestCase):
"""Basic authentication"""
@@ -146,7 +163,7 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
- auth = "Token " + self.key
+ auth = 'Token ' + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -222,3 +239,317 @@ class IncorrectCredentialsTests(TestCase):
response = view(request)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.data, {'detail': 'Bad credentials'})
+
+
+class OAuthTests(TestCase):
+ """OAuth 1.0a authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ # these imports are here because oauth is optional and hiding them in try..except block or compat
+ # could obscure problems if something breaks
+ from oauth_provider.models import Consumer, Resource
+ from oauth_provider.models import Token as OAuthToken
+ from oauth_provider import consts
+
+ self.consts = consts
+
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CONSUMER_KEY = 'consumer_key'
+ self.CONSUMER_SECRET = 'consumer_secret'
+ self.TOKEN_KEY = "token_key"
+ self.TOKEN_SECRET = "token_secret"
+
+ self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='example', user=self.user, status=self.consts.ACCEPTED)
+
+ self.resource = Resource.objects.create(name="resource name", url="api/")
+ self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource,
+ token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
+ )
+
+ def _create_authorization_header(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+
+ return req.to_header()["Authorization"]
+
+ def _create_authorization_url_parameters(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+ return dict(req)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_passing_oauth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_repeated_nonce_failing_oauth(self):
+ """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ # simulate reply attack auth header containes already used (nonce, timestamp) pair
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_token_removed_failing_oauth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_consumer_status_not_accepted_failing_oauth(self):
+ """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
+ for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
+ self.consumer.status = consumer_status
+ self.consumer.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_request_token_failing_oauth(self):
+ """Ensure POSTing with unauthorized request token instead of access token fails"""
+ self.token.token_type = self.token.REQUEST
+ self.token.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_urlencoded_parameters(self):
+ """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_url_parameters(self):
+ """Ensure GETing with auth in url parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_hmac_sha1_signature_passes(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_readonly_resource_passing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.resource.is_readonly = True
+ read_only_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_readonly_resource_failing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.resource.is_readonly = True
+ read_only_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_write_resource_passing_auth(self):
+ """Ensure POSTing with a write resource succeed"""
+ read_write_access_token = self.token
+ read_write_access_token.resource.is_readonly = False
+ read_write_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+
+class OAuth2Tests(TestCase):
+ """OAuth 2.0 authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CLIENT_ID = 'client_key'
+ self.CLIENT_SECRET = 'client_secret'
+ self.ACCESS_TOKEN = "access_token"
+ self.REFRESH_TOKEN = "refresh_token"
+
+ self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ redirect_uri='',
+ client_type=0,
+ name='example',
+ user=None,
+ )
+
+ self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ token=self.ACCESS_TOKEN,
+ client=self.oauth2_client,
+ user=self.user,
+ )
+ self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ user=self.user,
+ access_token=self.access_token,
+ client=self.oauth2_client
+ )
+
+ def _create_authorization_header(self, token=None):
+ return "Bearer {0}".format(token or self.access_token.token)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_type_failing(self):
+ """Ensure that a wrong token type lead to the correct HTTP error status code"""
+ auth = "Wrong token-type-obsviously"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_format_failing(self):
+ """Ensure that a wrong token format lead to the correct HTTP error status code"""
+ auth = "Bearer wrong token format"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_failing(self):
+ """Ensure that a wrong token lead to the correct HTTP error status code"""
+ auth = "Bearer wrong-token"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth(self):
+ """Ensure GETing form over OAuth with correct client credentials succeed"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_token_removed_failing_auth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.access_token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_refresh_token_failing_auth(self):
+ """Ensure POSTing with refresh token instead of access token fails"""
+ auth = self._create_authorization_header(token=self.refresh_token.token)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_expired_access_token_failing_auth(self):
+ """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
+ self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
+ self.access_token.save()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+ self.assertIn('Invalid token', response.content)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_invalid_scope_failing_auth(self):
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
+ read_only_access_token = self.access_token
+ read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
+ read_only_access_token.save()
+ auth = self._create_authorization_header(token=read_only_access_token.token)
+ response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_valid_scope_passing_auth(self):
+ """Ensure POSTing with a write scope succeed"""
+ read_write_access_token = self.access_token
+ read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
+ read_write_access_token.save()
+ auth = self._create_authorization_header(token=read_write_access_token.token)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
index 5b3315bc..52c1a34c 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/description.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.views import APIView
from rest_framework.compat import apply_markdown
+from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
@@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
class TestViewNamesAndDescriptions(TestCase):
- def test_resource_name_uses_classname_by_default(self):
- """Ensure Resource names are based on the classname by default."""
+ def test_view_name_uses_class_name(self):
+ """
+ Ensure view names are based on the class name.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_name(), 'Mock')
+ self.assertEqual(get_view_name(MockView), 'Mock')
- def test_resource_name_can_be_set_explicitly(self):
- """Ensure Resource names can be set using the 'get_name' method."""
- example = 'Some Other Name'
- class MockView(APIView):
- def get_name(self):
- return example
- self.assertEqual(MockView().get_name(), example)
-
- def test_resource_description_uses_docstring_by_default(self):
- """Ensure Resource names are based on the docstring by default."""
+ def test_view_description_uses_docstring(self):
+ """Ensure view descriptions are based on the docstring."""
class MockView(APIView):
"""an example docstring
====================
@@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEqual(MockView().get_description(), DESCRIPTION)
-
- def test_resource_description_can_be_set_explicitly(self):
- """Ensure Resource descriptions can be set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- """docstring"""
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), DESCRIPTION)
- def test_resource_description_supports_unicode(self):
+ def test_view_description_supports_unicode(self):
+ """
+ Unicode in docstrings should be respected.
+ """
class MockView(APIView):
"""Проверка"""
pass
- self.assertEqual(MockView().get_description(), "Проверка")
-
-
- def test_resource_description_does_not_require_docstring(self):
- """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), "Проверка")
- def test_resource_description_can_be_empty(self):
- """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
+ def test_view_description_can_be_empty(self):
+ """
+ Ensure that if a view has no docstring,
+ then it's description is the empty string.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_description(), '')
+ self.assertEqual(get_view_description(MockView), '')
def test_markdown(self):
- """Ensure markdown to HTML works as expected"""
+ """
+ Ensure markdown to HTML works as expected.
+ """
if apply_markdown:
gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 840ed320..3cdfa0f6 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -3,10 +3,14 @@ General serializer field tests.
"""
from __future__ import unicode_literals
import datetime
+from decimal import Decimal
+
from django.db import models
from django.test import TestCase
from django.core import validators
+
from rest_framework import serializers
+from rest_framework.serializers import Serializer
class TimestampedModel(models.Model):
@@ -59,37 +63,586 @@ class BasicFieldTests(TestCase):
serializer = CharPrimaryKeyModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, False)
- def test_TimeField_from_native(self):
+
+class DateFieldTest(TestCase):
+ """
+ Tests for the DateFieldTest from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native('1984-07-31')
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_from_native_datetime_date(self):
+ """
+ Make sure from_native() accepts a datetime.date instance.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(result_1, datetime.date(1984, 7, 31))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+ result = f.from_native('1984 -- 31')
+
+ self.assertEqual(datetime.date(1984, 1, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+
+ try:
+ f.from_native('1984-07-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_date(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid date.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984-13-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns datetime as default.
+ """
+ f = serializers.DateField()
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with 'iso-8601' returns iso formated date.
+ """
+ f = serializers.DateField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984-07-31', result_1)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateField(format="%Y - %m.%d")
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984 - 07.31', result_1)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class DateTimeFieldTest(TestCase):
+ """
+ Tests for the DateTimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native('1984-07-31 04:31')
+ result_2 = f.from_native('1984-07-31 04:31:59')
+ result_3 = f.from_native('1984-07-31 04:31:59.000200')
+
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_datetime(self):
+ """
+ Make sure from_native() accepts a datetime.datetime instance.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
+ self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
+ self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+ result = f.from_native('1984 -- 04:59')
+
+ self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+
+ try:
+ f.from_native('1984-07-31 04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_datetime(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid datetime.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateTimeField()
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format=iso-8601 returns iso formatted datetime.
+ """
+ f = serializers.DateTimeField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984-07-31T00:00:00', result_1)
+ self.assertEqual('1984-07-31T04:31:00', result_2)
+ self.assertEqual('1984-07-31T04:31:59', result_3)
+ self.assertEqual('1984-07-31T04:31:59.000200', result_4)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateTimeField(format="%Y - %H:%M")
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984 - 00:00', result_1)
+ self.assertEqual('1984 - 04:31', result_2)
+ self.assertEqual('1984 - 04:31', result_3)
+ self.assertEqual('1984 - 04:31', result_4)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class TimeFieldTest(TestCase):
+ """
+ Tests for the TimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
f = serializers.TimeField()
- result = f.from_native('12:34:56.987654')
+ result_1 = f.from_native('04:31')
+ result_2 = f.from_native('04:31:59')
+ result_3 = f.from_native('04:31:59.000200')
- self.assertEqual(datetime.time(12, 34, 56, 987654), result)
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
- def test_TimeField_from_native_datetime_time(self):
+ def test_from_native_datetime_time(self):
"""
Make sure from_native() accepts a datetime.time instance.
"""
f = serializers.TimeField()
- result = f.from_native(datetime.time(12, 34, 56))
- self.assertEqual(result, datetime.time(12, 34, 56))
+ result_1 = f.from_native(datetime.time(4, 31))
+ result_2 = f.from_native(datetime.time(4, 31, 59))
+ result_3 = f.from_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.time(4, 31))
+ self.assertEqual(result_2, datetime.time(4, 31, 59))
+ self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+ result = f.from_native('04 -- 31')
- def test_TimeField_from_native_empty(self):
+ self.assertEqual(datetime.time(4, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+
+ try:
+ f.from_native('04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
f = serializers.TimeField()
result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native(None)
+
self.assertEqual(result, None)
- def test_TimeField_from_native_invalid_time(self):
+ def test_from_native_invalid_time(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid time.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns time object as default.
+ """
f = serializers.TimeField()
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format='iso-8601' returns iso formatted time.
+ """
+ f = serializers.TimeField(format='iso-8601')
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04:31:00', result_1)
+ self.assertEqual('04:31:59', result_2)
+ self.assertEqual('04:31:59.000200', result_3)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.TimeField(format="%H - %S [%f]")
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04 - 00 [000000]', result_1)
+ self.assertEqual('04 - 59 [000000]', result_2)
+ self.assertEqual('04 - 59 [000200]', result_3)
+
+
+class DecimalFieldTest(TestCase):
+ """
+ Tests for the DecimalField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts string values
+ """
+ f = serializers.DecimalField()
+ result_1 = f.from_native('9000')
+ result_2 = f.from_native('1.00000001')
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_from_native_invalid_string(self):
+ """
+ Make sure from_native() raises ValidationError on passing invalid string
+ """
+ f = serializers.DecimalField()
try:
- f.from_native('12:69:12')
+ f.from_native('123.45.6')
except validators.ValidationError as e:
- self.assertEqual(e.messages, ["'12:69:12' value has an invalid "
- "format. It must be a valid time "
- "in the HH:MM[:ss[.uuuuuu]] format."])
+ self.assertEqual(e.messages, ["Enter a number."])
else:
self.fail("ValidationError was not properly raised")
- def test_TimeFieldModelSerializer(self):
- serializer = TimeFieldModelSerializer()
- self.assertTrue(isinstance(serializer.fields['clock'], serializers.TimeField))
+ def test_from_native_integer(self):
+ """
+ Make sure from_native() accepts integer values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(9000)
+
+ self.assertEqual(Decimal('9000'), result)
+
+ def test_from_native_float(self):
+ """
+ Make sure from_native() accepts float values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(1.00000001)
+
+ self.assertEqual(Decimal('1.00000001'), result)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns Decimal as string.
+ """
+ f = serializers.DecimalField()
+
+ result_1 = f.to_native(Decimal('9000'))
+ result_2 = f.to_native(Decimal('1.00000001'))
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+ def test_valid_serialization(self):
+ """
+ Make sure the serializer works correctly
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=9010,
+ min_value=9000,
+ max_digits=6,
+ decimal_places=2)
+
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
+
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
+
+ def test_raise_max_value(self):
+ """
+ Make sure max_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '123'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
+
+ def test_raise_min_value(self):
+ """
+ Make sure min_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(min_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '99'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
+
+ def test_raise_max_digits(self):
+ """
+ Make sure max_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=5)
+
+ s = DecimalSerializer(data={'decimal_field': '123.456'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
+
+ def test_raise_max_decimal_places(self):
+ """
+ Make sure max_decimal_places violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '123.4567'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
+
+ def test_raise_max_whole_digits(self):
+ """
+ Make sure max_whole_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '12345.6'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file
diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py
new file mode 100644
index 00000000..8ae6d530
--- /dev/null
+++ b/rest_framework/tests/filters.py
@@ -0,0 +1,474 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.core.urlresolvers import reverse
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import unittest
+from rest_framework import generics, serializers, status, filters
+from rest_framework.compat import django_filters, patterns, url
+from rest_framework.tests.models import BasicModel
+
+factory = RequestFactory()
+
+
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # Regression test for #814
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+ class FilterFieldsQuerysetView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
+
+ urlpatterns = patterns('',
+ url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
+ )
+
+
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ self._serialize_object(obj)
+ for obj in self.objects.all()
+ ]
+
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_queryset(self):
+ """
+ Regression test for #814.
+ """
+ view = FilterFieldsQuerysetView.as_view()
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'rest_framework.tests.filters'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
deleted file mode 100644
index 8c13947c..00000000
--- a/rest_framework/tests/filterset.py
+++ /dev/null
@@ -1,169 +0,0 @@
-from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.test import TestCase
-from django.test.client import RequestFactory
-from django.utils import unittest
-from rest_framework import generics, status, filters
-from rest_framework.compat import django_filters
-from rest_framework.tests.models import FilterableItem, BasicModel
-
-factory = RequestFactory()
-
-
-if django_filters:
- # Basic filter on a list view.
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_fields = ['decimal', 'date']
- filter_backend = filters.DjangoFilterBackend
-
- # These class are used to test a filter class.
- class SeveralFieldsFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
- decimal = django_filters.NumberFilter(lookup_type='lt')
- date = django_filters.DateFilter(lookup_type='gt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterClassRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
- filter_backend = filters.DjangoFilterBackend
-
- # These classes are used to test a misconfigured filter class.
- class MisconfiguredFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
-
- class Meta:
- model = BasicModel
- fields = ['text']
-
- class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = MisconfiguredFilter
- filter_backend = filters.DjangoFilterBackend
-
-
-class IntegrationTestFiltering(TestCase):
- """
- Integration tests for filtered list views.
- """
-
- def setUp(self):
- """
- Create 10 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(10):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
- ]
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_filtered_fields_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- view = FilterFieldsRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter works.
- search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter works.
- search_date = datetime.date(2012, 9, 22)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_filtered_class_root_view(self):
- """
- GET requests to filtered ListCreateAPIView that have a filter_class set
- should return filtered results.
- """
- view = FilterClassRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter set with 'lt' in the filter class works.
- search_decimal = Decimal('4.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter set with 'gt' in the filter class works.
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the text filter set with 'icontains' in the filter class works.
- search_text = 'ff'
- request = factory.get('/?text=%s' % search_text)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if search_text in f['text'].lower()]
- self.assertEqual(response.data, expected_data)
-
- # Tests that multiple filters works.
- search_decimal = Decimal('5.25')
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_incorrectly_configured_filter(self):
- """
- An error should be displayed when the filter class is misconfigured.
- """
- view = IncorrectlyConfiguredRootView.as_view()
-
- request = factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_unknown_filter(self):
- """
- GET requests with filters that aren't configured should return 200.
- """
- view = FilterFieldsRootView.as_view()
-
- search_integer = 10
- request = factory.get('/?integer=%s' % search_integer)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f8f2ddaa..2799d143 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,5 +1,6 @@
from __future__ import unicode_literals
from django.db import models
+from django.shortcuts import get_object_or_404
from django.test import TestCase
from rest_framework import generics, serializers, status
from rest_framework.tests.utils import RequestFactory
@@ -38,6 +39,7 @@ class SlugBasedInstanceView(InstanceView):
"""
model = SlugBasedModel
serializer_class = SlugSerializer
+ lookup_field = 'slug'
class TestRootView(TestCase):
@@ -60,7 +62,8 @@ class TestRootView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
- response = self.view(request).render()
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data)
@@ -71,7 +74,8 @@ class TestRootView(TestCase):
content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4)
@@ -84,7 +88,8 @@ class TestRootView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
@@ -93,7 +98,8 @@ class TestRootView(TestCase):
DELETE requests to ListCreateAPIView should not be allowed
"""
request = factory.delete('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
@@ -102,7 +108,8 @@ class TestRootView(TestCase):
OPTIONS requests to ListCreateAPIView should return metadata
"""
request = factory.options('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
expected = {
'parses': [
'application/json',
@@ -126,7 +133,8 @@ class TestRootView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4)
@@ -154,7 +162,8 @@ class TestInstanceView(TestCase):
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
request = factory.get('/1')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, self.data[0])
@@ -165,7 +174,8 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
@@ -176,7 +186,8 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk='1').render()
+ with self.assertNumQueries(2):
+ response = self.view(request, pk='1').render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
@@ -190,7 +201,8 @@ class TestInstanceView(TestCase):
request = factory.patch('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
@@ -201,7 +213,8 @@ class TestInstanceView(TestCase):
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
"""
request = factory.delete('/1')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertEqual(response.content, six.b(''))
ids = [obj.id for obj in self.objects.all()]
@@ -212,7 +225,8 @@ class TestInstanceView(TestCase):
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
"""
request = factory.options('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
expected = {
'parses': [
'application/json',
@@ -236,7 +250,8 @@ class TestInstanceView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
@@ -251,7 +266,8 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
@@ -263,10 +279,11 @@ class TestInstanceView(TestCase):
at the requested url if it doesn't exist.
"""
content = {'text': 'foobar'}
- # pk fields can not be created on demand, only the database can set th pk for a new object
+ # pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=5).render()
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5)
self.assertEqual(new_obj.text, 'foobar')
@@ -279,13 +296,55 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/test_slug', json.dumps(content),
content_type='application/json')
- response = self.slug_based_view(request, slug='test_slug').render()
+ with self.assertNumQueries(2):
+ response = self.slug_based_view(request, slug='test_slug').render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar')
+class TestOverriddenGetObject(TestCase):
+ """
+ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
+ queryset/model mechanism but instead overrides get_object()
+ """
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example detail view for override of get_object().
+ """
+ model = BasicModel
+
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
+
+ self.view = OverriddenGetObjectView.as_view()
+
+ def test_overridden_get_object_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
# Regression test for #285
class CommentSerializer(serializers.ModelSerializer):
@@ -319,7 +378,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
self.assertEqual(created.content, 'foobar')
-# Test for particularly ugly regression with m2m in browseable API
+# Test for particularly ugly regression with m2m in browsable API
class ClassB(models.Model):
name = models.CharField(max_length=255)
@@ -344,9 +403,76 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowseableAPI(TestCase):
def test_m2m_in_browseable_api(self):
"""
- Test for particularly ugly regression with m2m in browseable API
+ Test for particularly ugly regression with m2m in browsable API
"""
request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view()
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class InclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='foo')
+
+
+class ExclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='other')
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances to filter on.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ def test_get_root_view_filters_by_name_with_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
+
+ def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return empty list when all models are filtered out.
+ """
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [])
+
+ def test_get_instance_view_filters_out_name_with_filter_backend(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
+ """
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.data, {'detail': 'Not found'})
+
+ def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
+ """
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 9a61f299..8fc6ba77 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer):
return Photo(**attrs)
+class AlbumSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
+
+ class Meta:
+ model = Album
+ fields = ('title', 'url')
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView):
class AlbumDetail(generics.RetrieveAPIView):
model = Album
+ serializer_class = AlbumSerializer
+ lookup_field = 'title'
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
@@ -180,6 +190,36 @@ class TestManyToManyHyperlinkedView(TestCase):
self.assertEqual(response.data, self.data[0])
+class TestHyperlinkedIdentityFieldLookup(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 Album instances.
+ """
+ titles = ['foo', 'bar', 'baz']
+ for title in titles:
+ album = Album(title=title)
+ album.save()
+ self.detail_view = AlbumDetail.as_view()
+ self.data = {
+ 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
+ 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
+ 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
+ }
+
+ def test_lookup_field(self):
+ """
+ GET requests to AlbumDetail view should return serialized Albums
+ with a url field keyed by `title`.
+ """
+ for album in Album.objects.all():
+ request = factory.get('/albums/{0}/'.format(album.title))
+ response = self.detail_view(request, title=album.title)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[album.title])
+
+
class TestCreateWithForeignKeys(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index f2117538..40e41a64 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -58,13 +58,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
-# Model to test filtering.
-class FilterableItem(RESTFrameworkModel):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
# Model for regression test for #285
class Comment(RESTFrameworkModel):
diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/multitable_inheritance.py
new file mode 100644
index 00000000..00c15327
--- /dev/null
+++ b/rest_framework/tests/multitable_inheritance.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import RESTFrameworkModel
+
+
+# Models
+class ParentModel(RESTFrameworkModel):
+ name1 = models.CharField(max_length=100)
+
+
+class ChildModel(ParentModel):
+ name2 = models.CharField(max_length=100)
+
+
+class AssociatedModel(RESTFrameworkModel):
+ ref = models.OneToOneField(ParentModel, primary_key=True)
+ name = models.CharField(max_length=100)
+
+
+# Serializers
+class DerivedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChildModel
+
+
+class AssociatedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AssociatedModel
+
+
+# Tests
+class IneritedModelSerializationTests(TestCase):
+
+ def test_multitable_inherited_model_fields_as_expected(self):
+ """
+ Assert that the parent pointer field is not included in the fields
+ serialized fields
+ """
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name1', 'name2', 'id']))
+
+ def test_onetoone_primary_key_model_fields_as_expected(self):
+ """
+ Assert that a model with a onetoone field that is the primary key is
+ not treated like a derived model
+ """
+ parent = ParentModel(name1='parent name')
+ associate = AssociatedModel(name='hello', ref=parent)
+ serializer = AssociatedModelSerializer(associate)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'ref']))
+
+ def test_data_is_valid_without_parent_ptr(self):
+ """
+ Assert that the pointer to the parent table is not a required field
+ for input data
+ """
+ data = {
+ 'name1': 'parent name',
+ 'name2': 'child name',
+ }
+ serializer = DerivedModelSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 6b9970a6..e538a78e 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,17 +1,24 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
+from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
-from rest_framework.tests.models import BasicModel, FilterableItem
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.
@@ -20,21 +27,6 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
-if django_filters:
- class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_class = DecimalFilter
- filter_backend = filters.DjangoFilterBackend
-
-
class DefaultPageSizeKwargView(generics.ListAPIView):
"""
View for testing default paginate_by_param usage
@@ -73,7 +65,9 @@ class IntegrationTestPagination(TestCase):
GET requests to paginated ListCreateAPIView should return paginated results.
"""
request = factory.get('/')
- response = self.view(request).render()
+ # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 26)
self.assertEqual(response.data['results'], self.data[:10])
@@ -81,7 +75,8 @@ class IntegrationTestPagination(TestCase):
self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 26)
self.assertEqual(response.data['results'], self.data[10:20])
@@ -89,7 +84,8 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 26)
self.assertEqual(response.data['results'], self.data[20:])
@@ -112,20 +108,37 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
]
- self.view = FilterFieldsRootView.as_view()
@unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_paginated_filtered_root_view(self):
+ def test_get_django_filter_paginated_filtered_root_view(self):
"""
GET requests to paginated filtered ListCreateAPIView should return
paginated results. The next and previous links should preserve the
filtered parameters.
"""
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ view = FilterFieldsRootView.as_view()
+
+ EXPECTED_NUM_QUERIES = 2
+
request = factory.get('/?decimal=15.20')
- response = self.view(request).render()
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 15)
self.assertEqual(response.data['results'], self.data[:10])
@@ -133,7 +146,8 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 15)
self.assertEqual(response.data['results'], self.data[10:15])
@@ -141,7 +155,53 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous'])
- response = self.view(request).render()
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ def test_get_basic_paginated_filtered_root_view(self):
+ """
+ Same as `test_get_django_filter_paginated_filtered_root_view`,
+ except using a custom filter backend instead of the django-filter
+ backend,
+ """
+
+ class DecimalFilterBackend(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
+
+ class BasicFilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_backends = (DecimalFilterBackend,)
+
+ view = BasicFilterFieldsRootView.as_view()
+
+ request = factory.get('/?decimal=15.20')
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['count'], 15)
self.assertEqual(response.data['results'], self.data[:10])
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py
index 539c5b44..7699e10c 100644
--- a/rest_framework/tests/parsers.py
+++ b/rest_framework/tests/parsers.py
@@ -1,10 +1,11 @@
from __future__ import unicode_literals
from rest_framework.compat import StringIO
from django import forms
+from django.core.files.uploadhandler import MemoryFileUploadHandler
from django.test import TestCase
from django.utils import unittest
from rest_framework.compat import etree
-from rest_framework.parsers import FormParser
+from rest_framework.parsers import FormParser, FileUploadParser
from rest_framework.parsers import XMLParser
import datetime
@@ -82,3 +83,33 @@ class TestXMLParser(TestCase):
parser = XMLParser()
data = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)
+
+
+class TestFileUploadParser(TestCase):
+ def setUp(self):
+ class MockRequest(object):
+ pass
+ from io import BytesIO
+ self.stream = BytesIO(
+ "Test text file".encode('utf-8')
+ )
+ request = MockRequest()
+ request.upload_handlers = (MemoryFileUploadHandler(),)
+ request.META = {
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
+ 'HTTP_CONTENT_LENGTH': 14,
+ }
+ self.parser_context = {'request': request, 'kwargs': {}}
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ data_and_files = parser.parse(self.stream, None, self.parser_context)
+ file_obj = data_and_files.files['file']
+ self.assertEqual(file_obj._size, 14)
+
+ def test_get_filename(self):
+ parser = FileUploadParser()
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'file.txt'.encode('utf-8'))
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index b5702a48..b1eed9a7 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -26,42 +26,44 @@ urlpatterns = patterns('',
)
+# ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
-
class Meta:
model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
-
class Meta:
model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('url', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
- nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
-
class Meta:
model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index a125ba65..f6d006b3 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
- model = ForeignKeySource
-
-
-class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+ depth = 1
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = FlatForeignKeySourceSerializer(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+ depth = 1
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
model = NullableForeignKeySource
-
-
-class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableOneToOneSource
+ fields = ('id', 'name', 'target')
+ depth = 1
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = NullableOneToOneSourceSerializer()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+ depth = 1
class ReverseForeignKeyTests(TestCase):
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index d6ae3176..5ce8b567 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore
from rest_framework.compat import six
+# ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = serializers.PrimaryKeyRelatedField()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
@@ -407,14 +410,14 @@ class PKNullableOneToOneTests(TestCase):
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
+ source = NullableOneToOneSource(name='source-1', target=new_target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': 1},
- {'id': 2, 'name': 'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'nullable_source': None},
+ {'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index 4892f7a6..97e5af20 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -58,6 +58,14 @@ class TestMethodOverloading(TestCase):
request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
self.assertEqual(request.method, 'DELETE')
+ def test_x_http_method_override_header(self):
+ """
+ POST requests can also be overloaded to another method by setting
+ the X-HTTP-Method-Override header.
+ """
+ request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
class TestContentParsing(TestCase):
def test_standard_behaviour_determines_no_content_GET(self):
diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py
new file mode 100644
index 00000000..4e4765cb
--- /dev/null
+++ b/rest_framework/tests/routers.py
@@ -0,0 +1,55 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import status
+from rest_framework.response import Response
+from rest_framework import viewsets
+from rest_framework.decorators import link, action
+from rest_framework.routers import SimpleRouter
+import copy
+
+factory = RequestFactory()
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @action()
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @link()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint))
+ # check method to function mapping
+ if endpoint.startswith('action'):
+ method_map = 'post'
+ else:
+ method_map = 'get'
+ self.assertEqual(route.mapping[method_map], endpoint)
+
+
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index d0300f9e..db3881f9 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -3,7 +3,7 @@ from django.utils.datastructures import MultiValueDict
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
- BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
+ BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
import datetime
import pickle
@@ -78,6 +78,18 @@ class PersonSerializer(serializers.ModelSerializer):
read_only_fields = ('age',)
+class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
+ """
+ Testing for #652.
+ """
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age', 'info')
+
+
class AlbumsSerializer(serializers.ModelSerializer):
class Meta:
@@ -189,6 +201,12 @@ class BasicTests(TestCase):
# Assert age is unchanged (35)
self.assertEqual(instance.age, self.person_data['age'])
+ def test_invalid_read_only_fields(self):
+ """
+ Regression test for #652.
+ """
+ self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+
class DictStyleSerializer(serializers.Serializer):
"""
@@ -261,25 +279,6 @@ class ValidationTests(TestCase):
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.errors, {})
- def test_bad_type_data_is_false(self):
- """
- Data of the wrong type is not valid.
- """
- data = ['i am', 'a', 'list']
- serializer = CommentSerializer(self.comment, data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
- data = 'and i am a string'
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
- data = 42
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
def test_cross_field_validation(self):
class CommentSerializerWithCrossFieldValidator(CommentSerializer):
@@ -376,7 +375,6 @@ class CustomValidationTests(TestCase):
def validate_email(self, attrs, source):
value = attrs[source]
-
return attrs
def validate_content(self, attrs, source):
@@ -757,6 +755,43 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post")
@@ -786,8 +821,6 @@ class RelatedTraversalTest(TestCase):
post = BlogPost.objects.create(title="Test blog post", writer=user)
post.blogpostcomment_set.create(text="I love this blog post")
- from rest_framework.tests.models import BlogPostComment
-
class PersonSerializer(serializers.ModelSerializer):
class Meta:
model = Person
@@ -987,23 +1020,26 @@ class SerializerPickleTests(TestCase):
class DepthTest(TestCase):
def test_implicit_nesting(self):
+
writer = Person.objects.create(name="django", age=1)
post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
- class BlogPostSerializer(serializers.ModelSerializer):
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
class Meta:
- model = BlogPost
- depth = 1
+ model = BlogPostComment
+ depth = 2
- serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
self.assertEqual(serializer.data, expected)
def test_explicit_nesting(self):
writer = Person.objects.create(name="django", age=1)
post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
class PersonSerializer(serializers.ModelSerializer):
class Meta:
@@ -1015,9 +1051,15 @@ class DepthTest(TestCase):
class Meta:
model = BlogPost
- serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ blog_post = BlogPostSerializer()
+
+ class Meta:
+ model = BlogPostComment
+
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
self.assertEqual(serializer.data, expected)
@@ -1072,3 +1114,32 @@ class NestedSerializerContextTests(TestCase):
# This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
+
+
+class DeserializeListTestCase(TestCase):
+
+ def setUp(self):
+ self.data = {
+ 'email': 'nobody@nowhere.com',
+ 'content': 'This is some test content',
+ 'created': datetime.datetime(2013, 3, 7),
+ }
+
+ def test_no_errors(self):
+ data = [self.data.copy() for x in range(0, 3)]
+ serializer = CommentSerializer(data=data, many=True)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, list))
+ self.assertTrue(
+ all((isinstance(item, Comment) for item in serializer.object))
+ )
+
+ def test_errors_return_as_list(self):
+ invalid_item = self.data.copy()
+ invalid_item['email'] = ''
+ data = [self.data.copy(), invalid_item, self.data.copy()]
+
+ serializer = CommentSerializer(data=data, many=True)
+ self.assertFalse(serializer.is_valid())
+ expected = [{}, {'email': ['This field is required.']}, {}]
+ self.assertEqual(serializer.errors, expected)
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py
new file mode 100644
index 00000000..8b0ded1a
--- /dev/null
+++ b/rest_framework/tests/serializer_bulk_update.py
@@ -0,0 +1,278 @@
+"""
+Tests to cover bulk create and update using serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class BulkCreateSerializerTests(TestCase):
+ """
+ Creating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ self.BookSerializer = BookSerializer
+
+ def test_bulk_create_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_bulk_create_errors(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 'foo',
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_list_datatype(self):
+ """
+ Data containing list of incorrect data type should return errors.
+ """
+ data = ['foo', 'bar', 'baz']
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = [
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']}
+ ]
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_datatype(self):
+ """
+ Data containing a single incorrect data type should return errors.
+ """
+ data = 123
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_object(self):
+ """
+ Data containing only a single object, instead of a list of objects
+ should return errors.
+ """
+ data = {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class BulkUpdateSerializerTests(TestCase):
+ """
+ Updating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class Book(object):
+ """
+ A data type that can be persisted to a mock storage backend
+ with `.save()` and `.delete()`.
+ """
+ object_map = {}
+
+ def __init__(self, id, title, author):
+ self.id = id
+ self.title = title
+ self.author = author
+
+ def save(self):
+ Book.object_map[self.id] = self
+
+ def delete(self):
+ del Book.object_map[self.id]
+
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.id = attrs['id']
+ instance.title = attrs['title']
+ instance.author = attrs['author']
+ return instance
+ return Book(**attrs)
+
+ self.Book = Book
+ self.BookSerializer = BookSerializer
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ for item in data:
+ book = Book(item['id'], item['title'], item['author'])
+ book.save()
+
+ def books(self):
+ """
+ Return all the objects in the mock storage backend.
+ """
+ return self.Book.object_map.values()
+
+ def test_bulk_update_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 2,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_and_create(self):
+ """
+ Bulk update serialization may also include created items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_invalid_create(self):
+ """
+ Bulk update serialization without allow_add_remove may not create items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_bulk_update_error(self):
+ """
+ Incorrect bulk update serialization should return error data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 'foo',
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py
new file mode 100644
index 00000000..71d0e24b
--- /dev/null
+++ b/rest_framework/tests/serializer_nested.py
@@ -0,0 +1,246 @@
+"""
+Tests to cover nested serializers.
+
+Doesn't cover model serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class WritableNestedSerializerBasicTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ Basic tests that use serializers that simply restore to dicts.
+ """
+
+ def setUp(self):
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ expected_errors = {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_many_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data
+ when multiple entities are being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ ]
+ expected_errors = [
+ {},
+ {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+ ]
+
+ serializer = self.AlbumSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class WritableNestedSerializerObjectTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ These tests use serializers that restore to concrete objects.
+ """
+
+ def setUp(self):
+ # Couple of concrete objects that we're going to deserialize into
+ class Track(object):
+ def __init__(self, order, title, duration):
+ self.order, self.title, self.duration = order, title, duration
+
+ def __eq__(self, other):
+ return (
+ self.order == other.order and
+ self.title == other.title and
+ self.duration == other.duration
+ )
+
+ class Album(object):
+ def __init__(self, album_name, artist, tracks):
+ self.album_name, self.artist, self.tracks = album_name, artist, tracks
+
+ def __eq__(self, other):
+ return (
+ self.album_name == other.album_name and
+ self.artist == other.artist and
+ self.tracks == other.tracks
+ )
+
+ # And their corresponding serializers
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ def restore_object(self, attrs, instance=None):
+ return Track(attrs['order'], attrs['title'], attrs['duration'])
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ def restore_object(self, attrs, instance=None):
+ return Album(attrs['album_name'], attrs['artist'], attrs['tracks'])
+
+ self.Album, self.Track = Album, Track
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return a restored object
+ that corresponds to the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ expected_object = self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
+
+ def test_many_nested_validation_success(self):
+ """
+ Correct nested serialization should return multiple restored objects
+ that corresponds to the input data when multiple objects are
+ being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ ]
+ expected_object = [
+ self.Album(
+ album_name='Russian Red',
+ artist='I Love Your Glasses',
+ tracks=[
+ self.Track(order=1, title='Cigarettes', duration=121),
+ self.Track(order=2, title='No Past Land', duration=198),
+ self.Track(order=3, title='They Don\'t Believe', duration=191),
+ ]
+ ),
+ self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+ ]
+
+ serializer = self.AlbumSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py
deleted file mode 100644
index e1644a6b..00000000
--- a/rest_framework/tests/status.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""Tests for the status module"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import status
-
-
-class TestStatus(TestCase):
- """Simple sanity test to check the status module"""
-
- def test_status(self):
- """Ensure the status module is present and correct."""
- self.assertEqual(200, status.HTTP_200_OK)
- self.assertEqual(404, status.HTTP_404_NOT_FOUND)