aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/authentication.py466
-rw-r--r--rest_framework/tests/breadcrumbs.py1
-rw-r--r--rest_framework/tests/decorators.py41
-rw-r--r--rest_framework/tests/description.py24
-rw-r--r--rest_framework/tests/fields.py403
-rw-r--r--rest_framework/tests/files.py16
-rw-r--r--rest_framework/tests/filterset.py48
-rw-r--r--rest_framework/tests/genericrelations.py103
-rw-r--r--rest_framework/tests/generics.py224
-rw-r--r--rest_framework/tests/htmlrenderer.py29
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py27
-rw-r--r--rest_framework/tests/models.py52
-rw-r--r--rest_framework/tests/modelviews.py90
-rw-r--r--rest_framework/tests/multitable_inheritance.py67
-rw-r--r--rest_framework/tests/negotiation.py15
-rw-r--r--rest_framework/tests/pagination.py246
-rw-r--r--rest_framework/tests/parsers.py140
-rw-r--r--rest_framework/tests/permissions.py153
-rw-r--r--rest_framework/tests/relations.py16
-rw-r--r--rest_framework/tests/relations_hyperlink.py323
-rw-r--r--rest_framework/tests/relations_nested.py45
-rw-r--r--rest_framework/tests/relations_pk.py289
-rw-r--r--rest_framework/tests/relations_slug.py257
-rw-r--r--rest_framework/tests/renderers.py113
-rw-r--r--rest_framework/tests/request.py30
-rw-r--r--rest_framework/tests/response.py60
-rw-r--r--rest_framework/tests/reverse.py3
-rw-r--r--rest_framework/tests/serializer.py497
-rw-r--r--rest_framework/tests/settings.py1
-rw-r--r--rest_framework/tests/status.py5
-rw-r--r--rest_framework/tests/testcases.py1
-rw-r--r--rest_framework/tests/tests.py1
-rw-r--r--rest_framework/tests/throttling.py5
-rw-r--r--rest_framework/tests/urlpatterns.py76
-rw-r--r--rest_framework/tests/utils.py21
-rw-r--r--rest_framework/tests/validation.py65
-rw-r--r--rest_framework/tests/validators.py329
-rw-r--r--rest_framework/tests/views.py27
38 files changed, 2855 insertions, 1454 deletions
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index e86041bc..b663ca48 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,33 +1,65 @@
+from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import Client, TestCase
-
+from django.utils import unittest
+from rest_framework import HTTP_HEADER_ENCODING
+from rest_framework import exceptions
from rest_framework import permissions
+from rest_framework import status
+from rest_framework.authentication import (
+ BaseAuthentication,
+ TokenAuthentication,
+ BasicAuthentication,
+ SessionAuthentication,
+ OAuthAuthentication,
+ OAuth2Authentication
+)
from rest_framework.authtoken.models import Token
-from rest_framework.authentication import TokenAuthentication
-from rest_framework.compat import patterns
+from rest_framework.compat import patterns, url, include
+from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
+from rest_framework.compat import oauth, oauth_provider
+from rest_framework.tests.utils import RequestFactory
from rest_framework.views import APIView
-
import json
import base64
+import time
+import datetime
+
+factory = RequestFactory()
class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,)
+ def get(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
def post(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-MockView.authentication_classes += (TokenAuthentication,)
urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ permission_classes=[permissions.TokenHasReadWriteScope]))
)
+if oauth2_provider is not None:
+ urlpatterns += patterns('',
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
+ url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ permission_classes=[permissions.TokenHasReadWriteScope])),
+ )
+
class BasicAuthTests(TestCase):
"""Basic authentication"""
@@ -42,25 +74,30 @@ class BasicAuthTests(TestCase):
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
- auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
- auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
class SessionAuthTests(TestCase):
@@ -83,31 +120,31 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_post_form_session_auth_passing(self):
"""
Ensure POSTing form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 200)
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_put_form_session_auth_passing(self):
"""
Ensure PUTting form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.put('/', {'example': 'example'})
- self.assertEqual(response.status_code, 200)
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
class TokenAuthTests(TestCase):
@@ -126,25 +163,25 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
- auth = "Token " + self.key
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
+ auth = 'Token ' + self.key
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
@@ -157,8 +194,8 @@ class TokenAuthTests(TestCase):
client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': self.password}), 'application/json')
- self.assertEqual(response.status_code, 200)
- self.assertEqual(json.loads(response.content)['token'], self.key)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)
def test_token_login_json_bad_creds(self):
"""Ensure token login view using JSON POST fails if bad credentials are used."""
@@ -179,5 +216,362 @@ class TokenAuthTests(TestCase):
client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/',
{'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)
+
+
+class IncorrectCredentialsTests(TestCase):
+ def test_incorrect_credentials(self):
+ """
+ If a request contains bad authentication credentials, then
+ authentication should run and error, even if no permissions
+ are set on the view.
+ """
+ class IncorrectCredentialsAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('Bad credentials')
+
+ request = factory.get('/')
+ view = MockView.as_view(
+ authentication_classes=(IncorrectCredentialsAuth,),
+ permission_classes=()
+ )
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.data, {'detail': 'Bad credentials'})
+
+
+class OAuthTests(TestCase):
+ """OAuth 1.0a authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ # these imports are here because oauth is optional and hiding them in try..except block or compat
+ # could obscure problems if something breaks
+ from oauth_provider.models import Consumer, Resource
+ from oauth_provider.models import Token as OAuthToken
+ from oauth_provider import consts
+
+ self.consts = consts
+
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CONSUMER_KEY = 'consumer_key'
+ self.CONSUMER_SECRET = 'consumer_secret'
+ self.TOKEN_KEY = "token_key"
+ self.TOKEN_SECRET = "token_secret"
+
+ self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='example', user=self.user, status=self.consts.ACCEPTED)
+
+ self.resource = Resource.objects.create(name="resource name", url="api/")
+ self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource,
+ token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
+ )
+
+ def _create_authorization_header(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+
+ return req.to_header()["Authorization"]
+
+ def _create_authorization_url_parameters(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+ return dict(req)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_passing_oauth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_repeated_nonce_failing_oauth(self):
+ """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ # simulate reply attack auth header containes already used (nonce, timestamp) pair
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_token_removed_failing_oauth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_consumer_status_not_accepted_failing_oauth(self):
+ """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
+ for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
+ self.consumer.status = consumer_status
+ self.consumer.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_request_token_failing_oauth(self):
+ """Ensure POSTing with unauthorized request token instead of access token fails"""
+ self.token.token_type = self.token.REQUEST
+ self.token.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_urlencoded_parameters(self):
+ """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_url_parameters(self):
+ """Ensure GETing with auth in url parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_hmac_sha1_signature_passes(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_readonly_resource_passing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.resource.is_readonly = True
+ read_only_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_readonly_resource_failing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.resource.is_readonly = True
+ read_only_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_write_resource_passing_auth(self):
+ """Ensure POSTing with a write resource succeed"""
+ read_write_access_token = self.token
+ read_write_access_token.resource.is_readonly = False
+ read_write_access_token.resource.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+
+class OAuth2Tests(TestCase):
+ """OAuth 2.0 authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CLIENT_ID = 'client_key'
+ self.CLIENT_SECRET = 'client_secret'
+ self.ACCESS_TOKEN = "access_token"
+ self.REFRESH_TOKEN = "refresh_token"
+
+ self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ redirect_uri='',
+ client_type=0,
+ name='example',
+ user=None,
+ )
+
+ self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ token=self.ACCESS_TOKEN,
+ client=self.oauth2_client,
+ user=self.user,
+ )
+ self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ user=self.user,
+ access_token=self.access_token,
+ client=self.oauth2_client
+ )
+
+ def _create_authorization_header(self, token=None):
+ return "Bearer {0}".format(token or self.access_token.token)
+
+ def _client_credentials_params(self):
+ return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_type_failing(self):
+ """Ensure that a wrong token type lead to the correct HTTP error status code"""
+ auth = "Wrong token-type-obsviously"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_format_failing(self):
+ """Ensure that a wrong token format lead to the correct HTTP error status code"""
+ auth = "Bearer wrong token format"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_failing(self):
+ """Ensure that a wrong token lead to the correct HTTP error status code"""
+ auth = "Bearer wrong-token"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_client_data_failing_auth(self):
+ """Ensure GETing form over OAuth with incorrect client credentials fails"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ params['client_id'] += 'a'
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth(self):
+ """Ensure GETing form over OAuth with correct client credentials succeed"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_token_removed_failing_auth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.access_token.delete()
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_refresh_token_failing_auth(self):
+ """Ensure POSTing with refresh token instead of access token fails"""
+ auth = self._create_authorization_header(token=self.refresh_token.token)
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_expired_access_token_failing_auth(self):
+ """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
+ self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
+ self.access_token.save()
+ auth = self._create_authorization_header()
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+ self.assertIn('Invalid token', response.content)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_invalid_scope_failing_auth(self):
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
+ read_only_access_token = self.access_token
+ read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
+ read_only_access_token.save()
+ auth = self._create_authorization_header(token=read_only_access_token.token)
+ params = self._client_credentials_params()
+ response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_valid_scope_passing_auth(self):
+ """Ensure POSTing with a write scope succeed"""
+ read_write_access_token = self.access_token
+ read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
+ read_write_access_token.save()
+ auth = self._create_authorization_header(token=read_write_access_token.token)
+ params = self._client_credentials_params()
+ response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
- self.assertEqual(json.loads(response.content)['token'], self.key)
diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py
index df891683..d9ed647e 100644
--- a/rest_framework/tests/breadcrumbs.py
+++ b/rest_framework/tests/breadcrumbs.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.utils.breadcrumbs import get_breadcrumbs
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index 5e6bce4e..1016fed3 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import status
from rest_framework.response import Response
@@ -28,13 +29,27 @@ class DecoratorTestCase(TestCase):
response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs)
- def test_wrap_view(self):
+ def test_api_view_incorrect(self):
+ """
+ If @api_view is not applied correct, we should raise an assertion.
+ """
- @api_view(['GET'])
+ @api_view
def view(request):
- return Response({})
+ return Response()
+
+ request = self.factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
- self.assertTrue(isinstance(view.cls_instance, APIView))
+ with self.assertRaises(AssertionError):
+ @api_view('GET')
+ def view(request):
+ return Response()
def test_calling_method(self):
@@ -44,11 +59,11 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/')
response = view(request)
- self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
request = self.factory.post('/')
response = view(request)
- self.assertEqual(response.status_code, 405)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_calling_put_method(self):
@@ -58,11 +73,11 @@ class DecoratorTestCase(TestCase):
request = self.factory.put('/')
response = view(request)
- self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
request = self.factory.post('/')
response = view(request)
- self.assertEqual(response.status_code, 405)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_calling_patch_method(self):
@@ -72,11 +87,11 @@ class DecoratorTestCase(TestCase):
request = self.factory.patch('/')
response = view(request)
- self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
request = self.factory.post('/')
response = view(request)
- self.assertEqual(response.status_code, 405)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_renderer_classes(self):
@@ -124,7 +139,7 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/')
response = view(request)
- self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle):
@@ -137,7 +152,7 @@ class DecoratorTestCase(TestCase):
request = self.factory.get('/')
response = view(request)
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
response = view(request)
- self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
+ self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
index d958b840..5b3315bc 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/description.py
@@ -1,3 +1,6 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.views import APIView
from rest_framework.compat import apply_markdown
@@ -50,7 +53,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""Ensure Resource names are based on the classname by default."""
class MockView(APIView):
pass
- self.assertEquals(MockView().get_name(), 'Mock')
+ self.assertEqual(MockView().get_name(), 'Mock')
def test_resource_name_can_be_set_explicitly(self):
"""Ensure Resource names can be set using the 'get_name' method."""
@@ -58,7 +61,7 @@ class TestViewNamesAndDescriptions(TestCase):
class MockView(APIView):
def get_name(self):
return example
- self.assertEquals(MockView().get_name(), example)
+ self.assertEqual(MockView().get_name(), example)
def test_resource_description_uses_docstring_by_default(self):
"""Ensure Resource names are based on the docstring by default."""
@@ -78,7 +81,7 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEquals(MockView().get_description(), DESCRIPTION)
+ self.assertEqual(MockView().get_description(), DESCRIPTION)
def test_resource_description_can_be_set_explicitly(self):
"""Ensure Resource descriptions can be set using the 'get_description' method."""
@@ -88,7 +91,16 @@ class TestViewNamesAndDescriptions(TestCase):
"""docstring"""
def get_description(self):
return example
- self.assertEquals(MockView().get_description(), example)
+ self.assertEqual(MockView().get_description(), example)
+
+ def test_resource_description_supports_unicode(self):
+
+ class MockView(APIView):
+ """Проверка"""
+ pass
+
+ self.assertEqual(MockView().get_description(), "Проверка")
+
def test_resource_description_does_not_require_docstring(self):
"""Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
@@ -97,13 +109,13 @@ class TestViewNamesAndDescriptions(TestCase):
class MockView(APIView):
def get_description(self):
return example
- self.assertEquals(MockView().get_description(), example)
+ self.assertEqual(MockView().get_description(), example)
def test_resource_description_can_be_empty(self):
"""Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
class MockView(APIView):
pass
- self.assertEquals(MockView().get_description(), '')
+ self.assertEqual(MockView().get_description(), '')
def test_markdown(self):
"""Ensure markdown to HTML works as expected"""
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 8068272d..fd6de779 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -1,9 +1,13 @@
"""
General serializer field tests.
"""
+from __future__ import unicode_literals
+import datetime
from django.db import models
from django.test import TestCase
+from django.core import validators
+
from rest_framework import serializers
@@ -26,24 +30,415 @@ class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
model = CharPrimaryKeyModel
-class ReadOnlyFieldTests(TestCase):
+class TimeFieldModel(models.Model):
+ clock = models.TimeField()
+
+
+class TimeFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimeFieldModel
+
+
+class BasicFieldTests(TestCase):
def test_auto_now_fields_read_only(self):
"""
auto_now and auto_now_add fields should be read_only by default.
"""
serializer = TimestampedModelSerializer()
- self.assertEquals(serializer.fields['added'].read_only, True)
+ self.assertEqual(serializer.fields['added'].read_only, True)
def test_auto_pk_fields_read_only(self):
"""
AutoField fields should be read_only by default.
"""
serializer = TimestampedModelSerializer()
- self.assertEquals(serializer.fields['id'].read_only, True)
+ self.assertEqual(serializer.fields['id'].read_only, True)
def test_non_auto_pk_fields_not_read_only(self):
"""
PK fields other than AutoField fields should not be read_only by default.
"""
serializer = CharPrimaryKeyModelSerializer()
- self.assertEquals(serializer.fields['id'].read_only, False)
+ self.assertEqual(serializer.fields['id'].read_only, False)
+
+
+class DateFieldTest(TestCase):
+ """
+ Tests for the DateFieldTest from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native('1984-07-31')
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_from_native_datetime_date(self):
+ """
+ Make sure from_native() accepts a datetime.date instance.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(result_1, datetime.date(1984, 7, 31))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+ result = f.from_native('1984 -- 31')
+
+ self.assertEqual(datetime.date(1984, 1, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+
+ try:
+ f.from_native('1984-07-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_date(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid date.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984-13-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateField()
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984-07-31', result_1)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateField(format="%Y - %m.%d")
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984 - 07.31', result_1)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class DateTimeFieldTest(TestCase):
+ """
+ Tests for the DateTimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native('1984-07-31 04:31')
+ result_2 = f.from_native('1984-07-31 04:31:59')
+ result_3 = f.from_native('1984-07-31 04:31:59.000200')
+
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_datetime(self):
+ """
+ Make sure from_native() accepts a datetime.datetime instance.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
+ self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
+ self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+ result = f.from_native('1984 -- 04:59')
+
+ self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+
+ try:
+ f.from_native('1984-07-31 04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_datetime(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid datetime.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateTimeField()
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984-07-31T00:00:00', result_1)
+ self.assertEqual('1984-07-31T04:31:00', result_2)
+ self.assertEqual('1984-07-31T04:31:59', result_3)
+ self.assertEqual('1984-07-31T04:31:59.000200', result_4)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateTimeField(format="%Y - %H:%M")
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984 - 00:00', result_1)
+ self.assertEqual('1984 - 04:31', result_2)
+ self.assertEqual('1984 - 04:31', result_3)
+ self.assertEqual('1984 - 04:31', result_4)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class TimeFieldTest(TestCase):
+ """
+ Tests for the TimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.TimeField()
+ result_1 = f.from_native('04:31')
+ result_2 = f.from_native('04:31:59')
+ result_3 = f.from_native('04:31:59.000200')
+
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_time(self):
+ """
+ Make sure from_native() accepts a datetime.time instance.
+ """
+ f = serializers.TimeField()
+ result_1 = f.from_native(datetime.time(4, 31))
+ result_2 = f.from_native(datetime.time(4, 31, 59))
+ result_3 = f.from_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.time(4, 31))
+ self.assertEqual(result_2, datetime.time(4, 31, 59))
+ self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+ result = f.from_native('04 -- 31')
+
+ self.assertEqual(datetime.time(4, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+
+ try:
+ f.from_native('04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_time(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid time.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.TimeField()
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04:31:00', result_1)
+ self.assertEqual('04:31:59', result_2)
+ self.assertEqual('04:31:59.000200', result_3)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.TimeField(format="%H - %S [%f]")
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04 - 00 [000000]', result_1)
+ self.assertEqual('04 - 59 [000000]', result_2)
+ self.assertEqual('04 - 59 [000200]', result_3)
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
index 446e23c0..487046ac 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/files.py
@@ -1,9 +1,9 @@
-import StringIO
-import datetime
-
+from __future__ import unicode_literals
from django.test import TestCase
-
from rest_framework import serializers
+from rest_framework.compat import BytesIO
+from rest_framework.compat import six
+import datetime
class UploadedFile(object):
@@ -27,14 +27,14 @@ class UploadedFileSerializer(serializers.Serializer):
class FileSerializerTests(TestCase):
def test_create(self):
now = datetime.datetime.now()
- file = StringIO.StringIO('stuff')
+ file = BytesIO(six.b('stuff'))
file.name = 'stuff.txt'
- file.size = file.len
+ file.size = len(file.getvalue())
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.object.created, uploaded_file.created)
- self.assertEquals(serializer.object.file, uploaded_file.file)
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertEqual(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file)
def test_creation_failure(self):
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index af2e6c2e..fe92e0bc 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
import datetime
from decimal import Decimal
from django.test import TestCase
@@ -64,8 +65,8 @@ class IntegrationTestFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ for obj in self.objects.all()
]
@unittest.skipUnless(django_filters, 'django-filters not installed')
@@ -78,24 +79,24 @@ class IntegrationTestFiltering(TestCase):
# Basic test with no filter.
request = factory.get('/')
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
# Tests that the decimal filter works.
search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.data, expected_data)
# Tests that the date filter works.
search_date = datetime.date(2012, 9, 22)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date]
+ self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_class_root_view(self):
@@ -108,42 +109,43 @@ class IntegrationTestFiltering(TestCase):
# Basic test with no filter.
request = factory.get('/')
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
# Tests that the decimal filter set with 'lt' in the filter class works.
search_decimal = Decimal('4.25')
request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.data, expected_data)
# Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date]
+ self.assertEqual(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works.
search_text = 'ff'
request = factory.get('/?text=%s' % search_text)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.data, expected_data)
# Tests that multiple filters works.
search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2)
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
- self.assertEquals(response.data, expected_data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if
+ datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
def test_incorrectly_configured_filter(self):
@@ -165,4 +167,4 @@ class IntegrationTestFiltering(TestCase):
search_integer = 10
request = factory.get('/?integer=%s' % search_integer)
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
index bc7378e1..c38bfb9f 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/genericrelations.py
@@ -1,25 +1,62 @@
+from __future__ import unicode_literals
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import *
+
+
+class Tag(models.Model):
+ """
+ Tags have a descriptive slug, and are attached to an arbitrary object.
+ """
+ tag = models.SlugField()
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ tagged_item = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag
+
+
+class Bookmark(models.Model):
+ """
+ A URL bookmark that may have multiple tags attached.
+ """
+ url = models.URLField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Bookmark: %s' % self.url
+
+
+class Note(models.Model):
+ """
+ A textual note that may have multiple tags attached.
+ """
+ text = models.TextField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Note: %s' % self.text
class TestGenericRelations(TestCase):
def setUp(self):
- bookmark = Bookmark(url='https://www.djangoproject.com/')
- bookmark.save()
- django = Tag(tag_name='django')
- django.save()
- python = Tag(tag_name='python')
- python.save()
- t1 = TaggedItem(content_object=bookmark, tag=django)
- t1.save()
- t2 = TaggedItem(content_object=bookmark, tag=python)
- t2.save()
- self.bookmark = bookmark
-
- def test_reverse_generic_relation(self):
+ self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
+ Tag.objects.create(tagged_item=self.bookmark, tag='django')
+ Tag.objects.create(tagged_item=self.bookmark, tag='python')
+ self.note = Note.objects.create(text='Remember the milk')
+ Tag.objects.create(tagged_item=self.note, tag='reminder')
+
+ def test_generic_relation(self):
+ """
+ Test a relationship that spans a GenericRelation field.
+ IE. A reverse generic relationship.
+ """
+
class BookmarkSerializer(serializers.ModelSerializer):
- tags = serializers.ManyRelatedField(source='tags')
+ tags = serializers.RelatedField(many=True)
class Meta:
model = Bookmark
@@ -27,7 +64,37 @@ class TestGenericRelations(TestCase):
serializer = BookmarkSerializer(self.bookmark)
expected = {
- 'tags': [u'django', u'python'],
- 'url': u'https://www.djangoproject.com/'
+ 'tags': ['django', 'python'],
+ 'url': 'https://www.djangoproject.com/'
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_generic_fk(self):
+ """
+ Test a relationship that spans a GenericForeignKey field.
+ IE. A forward generic relationship.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ tagged_item = serializers.RelatedField()
+
+ class Meta:
+ model = Tag
+ exclude = ('id', 'content_type', 'object_id')
+
+ serializer = TagSerializer(Tag.objects.all(), many=True)
+ expected = [
+ {
+ 'tag': 'django',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'python',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'reminder',
+ 'tagged_item': 'Note: Remember the milk'
}
- self.assertEquals(serializer.data, expected)
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index 4799a04b..f564890c 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,10 +1,11 @@
-import json
+from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import generics, serializers, status
from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
-
+from rest_framework.compat import six
+import json
factory = RequestFactory()
@@ -42,7 +43,7 @@ class SlugBasedInstanceView(InstanceView):
class TestRootView(TestCase):
def setUp(self):
"""
- Create 3 BasicModel intances.
+ Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
@@ -59,9 +60,10 @@ class TestRootView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
def test_post_root_view(self):
"""
@@ -70,11 +72,12 @@ class TestRootView(TestCase):
content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
- self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4)
- self.assertEquals(created.text, 'foobar')
+ self.assertEqual(created.text, 'foobar')
def test_put_root_view(self):
"""
@@ -83,25 +86,28 @@ class TestRootView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEquals(response.data, {"detail": "Method 'PUT' not allowed."})
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
def test_delete_root_view(self):
"""
DELETE requests to ListCreateAPIView should not be allowed
"""
request = factory.delete('/')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEquals(response.data, {"detail": "Method 'DELETE' not allowed."})
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
def test_options_root_view(self):
"""
OPTIONS requests to ListCreateAPIView should return metadata
"""
request = factory.options('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
expected = {
'parses': [
'application/json',
@@ -115,8 +121,8 @@ class TestRootView(TestCase):
'name': 'Root',
'description': 'Example description for OPTIONS.'
}
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, expected)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
def test_post_cannot_set_id(self):
"""
@@ -125,11 +131,12 @@ class TestRootView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
- self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4)
- self.assertEquals(created.text, 'foobar')
+ self.assertEqual(created.text, 'foobar')
class TestInstanceView(TestCase):
@@ -153,9 +160,10 @@ class TestInstanceView(TestCase):
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
request = factory.get('/1')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data[0])
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
def test_post_instance_view(self):
"""
@@ -164,9 +172,10 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content),
content_type='application/json')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEquals(response.data, {"detail": "Method 'POST' not allowed."})
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
def test_put_instance_view(self):
"""
@@ -175,11 +184,12 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk='1').render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ with self.assertNumQueries(2):
+ response = self.view(request, pk='1').render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
- self.assertEquals(updated.text, 'foobar')
+ self.assertEqual(updated.text, 'foobar')
def test_patch_instance_view(self):
"""
@@ -189,29 +199,32 @@ class TestInstanceView(TestCase):
request = factory.patch('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
- self.assertEquals(updated.text, 'foobar')
+ self.assertEqual(updated.text, 'foobar')
def test_delete_instance_view(self):
"""
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
"""
request = factory.delete('/1')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_204_NO_CONTENT)
- self.assertEquals(response.content, '')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ self.assertEqual(response.content, six.b(''))
ids = [obj.id for obj in self.objects.all()]
- self.assertEquals(ids, [2, 3])
+ self.assertEqual(ids, [2, 3])
def test_options_instance_view(self):
"""
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
"""
request = factory.options('/')
- response = self.view(request).render()
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
expected = {
'parses': [
'application/json',
@@ -225,8 +238,8 @@ class TestInstanceView(TestCase):
'name': 'Instance',
'description': 'Example description for OPTIONS.'
}
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, expected)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
def test_put_cannot_set_id(self):
"""
@@ -235,11 +248,12 @@ class TestInstanceView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
- self.assertEquals(updated.text, 'foobar')
+ self.assertEqual(updated.text, 'foobar')
def test_put_to_deleted_instance(self):
"""
@@ -250,11 +264,12 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
- self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
- self.assertEquals(updated.text, 'foobar')
+ self.assertEqual(updated.text, 'foobar')
def test_put_as_create_on_id_based_url(self):
"""
@@ -262,13 +277,14 @@ class TestInstanceView(TestCase):
at the requested url if it doesn't exist.
"""
content = {'text': 'foobar'}
- # pk fields can not be created on demand, only the database can set th pk for a new object
+ # pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', json.dumps(content),
content_type='application/json')
- response = self.view(request, pk=5).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=5).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5)
- self.assertEquals(new_obj.text, 'foobar')
+ self.assertEqual(new_obj.text, 'foobar')
def test_put_as_create_on_slug_based_url(self):
"""
@@ -278,11 +294,12 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/test_slug', json.dumps(content),
content_type='application/json')
- response = self.slug_based_view(request, slug='test_slug').render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
- self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ with self.assertNumQueries(2):
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
new_obj = SlugBasedModel.objects.get(slug='test_slug')
- self.assertEquals(new_obj.text, 'foobar')
+ self.assertEqual(new_obj.text, 'foobar')
# Regression test for #285
@@ -313,12 +330,12 @@ class TestCreateModelWithAutoNowAddField(TestCase):
request = factory.post('/', json.dumps(content),
content_type='application/json')
response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
created = self.objects.get(id=1)
- self.assertEquals(created.content, 'foobar')
+ self.assertEqual(created.content, 'foobar')
-# Test for particularly ugly reression with m2m in browseable API
+# Test for particularly ugly regression with m2m in browseable API
class ClassB(models.Model):
name = models.CharField(max_length=255)
@@ -329,7 +346,7 @@ class ClassA(models.Model):
class ClassASerializer(serializers.ModelSerializer):
- childs = serializers.ManyPrimaryKeyRelatedField(source='childs')
+ childs = serializers.PrimaryKeyRelatedField(many=True, source='childs')
class Meta:
model = ClassA
@@ -343,9 +360,84 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowseableAPI(TestCase):
def test_m2m_in_browseable_api(self):
"""
- Test for particularly ugly reression with m2m in browseable API
+ Test for particularly ugly regression with m2m in browseable API
"""
request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view()
response = view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class InclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='foo')
+
+
+class ExclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='other')
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances to filter on.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.root_view = RootView.as_view()
+ self.instance_view = InstanceView.as_view()
+ self.original_root_backend = getattr(RootView, 'filter_backend')
+ self.original_instance_backend = getattr(InstanceView, 'filter_backend')
+
+ def tearDown(self):
+ setattr(RootView, 'filter_backend', self.original_root_backend)
+ setattr(InstanceView, 'filter_backend', self.original_instance_backend)
+
+ def test_get_root_view_filters_by_name_with_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ setattr(RootView, 'filter_backend', InclusiveFilterBackend)
+ request = factory.get('/')
+ response = self.root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
+
+ def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return empty list when all models are filtered out.
+ """
+ setattr(RootView, 'filter_backend', ExclusiveFilterBackend)
+ request = factory.get('/')
+ response = self.root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [])
+
+ def test_get_instance_view_filters_out_name_with_filter_backend(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
+ """
+ setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend)
+ request = factory.get('/1')
+ response = self.instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.data, {'detail': 'Not found'})
+
+ def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
+ """
+ setattr(InstanceView, 'filter_backend', InclusiveFilterBackend)
+ request = factory.get('/1')
+ response = self.instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py
index 54096206..8f2e2b5a 100644
--- a/rest_framework/tests/htmlrenderer.py
+++ b/rest_framework/tests/htmlrenderer.py
@@ -1,12 +1,15 @@
+from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.test import TestCase
from django.template import TemplateDoesNotExist, Template
import django.template.loader
+from rest_framework import status
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view, renderer_classes
from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
+from rest_framework.compat import six
@api_view(('GET',))
@@ -63,19 +66,19 @@ class TemplateHTMLRendererTests(TestCase):
def test_simple_html_view(self):
response = self.client.get('/')
self.assertContains(response, "example: foobar")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html')
def test_not_found_html_view(self):
response = self.client.get('/not_found')
- self.assertEquals(response.status_code, 404)
- self.assertEquals(response.content, "404 Not Found")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404 Not Found"))
+ self.assertEqual(response['Content-Type'], 'text/html')
def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied')
- self.assertEquals(response.status_code, 403)
- self.assertEquals(response.content, "403 Forbidden")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403 Forbidden"))
+ self.assertEqual(response['Content-Type'], 'text/html')
class TemplateHTMLRendererExceptionTests(TestCase):
@@ -104,12 +107,12 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found')
- self.assertEquals(response.status_code, 404)
- self.assertEquals(response.content, "404: Not found")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404: Not found"))
+ self.assertEqual(response['Content-Type'], 'text/html')
def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied')
- self.assertEquals(response.status_code, 403)
- self.assertEquals(response.content, "403: Permission denied")
- self.assertEquals(response['Content-Type'], 'text/html')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403: Permission denied"))
+ self.assertEqual(response['Content-Type'], 'text/html')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index c6a8224b..9a61f299 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
import json
from django.test import TestCase
from django.test.client import RequestFactory
@@ -99,7 +100,7 @@ class TestBasicHyperlinkedView(TestCase):
def setUp(self):
"""
- Create 3 BasicModel intances.
+ Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
for item in items:
@@ -118,8 +119,8 @@ class TestBasicHyperlinkedView(TestCase):
"""
request = factory.get('/basic/')
response = self.list_view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
def test_get_detail_view(self):
"""
@@ -127,8 +128,8 @@ class TestBasicHyperlinkedView(TestCase):
"""
request = factory.get('/basic/1')
response = self.detail_view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data[0])
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
class TestManyToManyHyperlinkedView(TestCase):
@@ -136,7 +137,7 @@ class TestManyToManyHyperlinkedView(TestCase):
def setUp(self):
"""
- Create 3 BasicModel intances.
+ Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
anchors = []
@@ -166,8 +167,8 @@ class TestManyToManyHyperlinkedView(TestCase):
"""
request = factory.get('/manytomany/')
response = self.list_view(request)
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
def test_get_detail_view(self):
"""
@@ -175,8 +176,8 @@ class TestManyToManyHyperlinkedView(TestCase):
"""
request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1)
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data[0])
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
class TestCreateWithForeignKeys(TestCase):
@@ -234,7 +235,7 @@ class TestOptionalRelationHyperlinkedView(TestCase):
def setUp(self):
"""
- Create 1 OptionalRelationModel intances.
+ Create 1 OptionalRelationModel instances.
"""
OptionalRelationModel().save()
self.objects = OptionalRelationModel.objects
@@ -248,8 +249,8 @@ class TestOptionalRelationHyperlinkedView(TestCase):
"""
request = factory.get('/optionalrelationmodel-detail/1')
response = self.detail_view(request, pk=1)
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
def test_put_detail_view(self):
"""
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 93f09761..f2117538 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -1,35 +1,6 @@
+from __future__ import unicode_literals
from django.db import models
-from django.contrib.contenttypes.models import ContentType
-from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
-# from django.contrib.auth.models import Group
-
-
-# class CustomUser(models.Model):
-# """
-# A custom user model, which uses a 'through' table for the foreign key
-# """
-# username = models.CharField(max_length=255, unique=True)
-# groups = models.ManyToManyField(
-# to=Group, blank=True, null=True, through='UserGroupMap'
-# )
-
-# @models.permalink
-# def get_absolute_url(self):
-# return ('custom_user', (), {
-# 'pk': self.id
-# })
-
-
-# class UserGroupMap(models.Model):
-# user = models.ForeignKey(to=CustomUser)
-# group = models.ForeignKey(to=Group)
-
-# @models.permalink
-# def get_absolute_url(self):
-# return ('user_group_map', (), {
-# 'pk': self.id
-# })
def foobar():
return 'foobar'
@@ -86,27 +57,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
rel = models.ManyToManyField(Anchor)
-# Models to test generic relations
-
-
-class Tag(RESTFrameworkModel):
- tag_name = models.SlugField()
-
-
-class TaggedItem(RESTFrameworkModel):
- tag = models.ForeignKey(Tag, related_name='items')
- content_type = models.ForeignKey(ContentType)
- object_id = models.PositiveIntegerField()
- content_object = GenericForeignKey('content_type', 'object_id')
-
- def __unicode__(self):
- return self.tag.tag_name
-
-
-class Bookmark(RESTFrameworkModel):
- url = models.URLField()
- tags = GenericRelation(TaggedItem)
-
# Model to test filtering.
class FilterableItem(RESTFrameworkModel):
diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py
deleted file mode 100644
index f12e3b97..00000000
--- a/rest_framework/tests/modelviews.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# from rest_framework.compat import patterns, url
-# from django.forms import ModelForm
-# from django.contrib.auth.models import Group, User
-# from rest_framework.resources import ModelResource
-# from rest_framework.views import ListOrCreateModelView, InstanceModelView
-# from rest_framework.tests.models import CustomUser
-# from rest_framework.tests.testcases import TestModelsTestCase
-
-
-# class GroupResource(ModelResource):
-# model = Group
-
-
-# class UserForm(ModelForm):
-# class Meta:
-# model = User
-# exclude = ('last_login', 'date_joined')
-
-
-# class UserResource(ModelResource):
-# model = User
-# form = UserForm
-
-
-# class CustomUserResource(ModelResource):
-# model = CustomUser
-
-# urlpatterns = patterns('',
-# url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
-# url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
-# url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
-# url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
-# url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
-# url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
-# )
-
-
-# class ModelViewTests(TestModelsTestCase):
-# """Test the model views rest_framework provides"""
-# urls = 'rest_framework.tests.modelviews'
-
-# def test_creation(self):
-# """Ensure that a model object can be created"""
-# self.assertEqual(0, Group.objects.count())
-
-# response = self.client.post('/groups/', {'name': 'foo'})
-
-# self.assertEqual(response.status_code, 201)
-# self.assertEqual(1, Group.objects.count())
-# self.assertEqual('foo', Group.objects.all()[0].name)
-
-# def test_creation_with_m2m_relation(self):
-# """Ensure that a model object with a m2m relation can be created"""
-# group = Group(name='foo')
-# group.save()
-# self.assertEqual(0, User.objects.count())
-
-# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
-
-# self.assertEqual(response.status_code, 201)
-# self.assertEqual(1, User.objects.count())
-
-# user = User.objects.all()[0]
-# self.assertEqual('bar', user.username)
-# self.assertEqual('baz', user.password)
-# self.assertEqual(1, user.groups.count())
-
-# group = user.groups.all()[0]
-# self.assertEqual('foo', group.name)
-
-# def test_creation_with_m2m_relation_through(self):
-# """
-# Ensure that a model object with a m2m relation can be created where that
-# relation uses a through table
-# """
-# group = Group(name='foo')
-# group.save()
-# self.assertEqual(0, User.objects.count())
-
-# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
-
-# self.assertEqual(response.status_code, 201)
-# self.assertEqual(1, CustomUser.objects.count())
-
-# user = CustomUser.objects.all()[0]
-# self.assertEqual('bar', user.username)
-# self.assertEqual(1, user.groups.count())
-
-# group = user.groups.all()[0]
-# self.assertEqual('foo', group.name)
diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/multitable_inheritance.py
new file mode 100644
index 00000000..00c15327
--- /dev/null
+++ b/rest_framework/tests/multitable_inheritance.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import RESTFrameworkModel
+
+
+# Models
+class ParentModel(RESTFrameworkModel):
+ name1 = models.CharField(max_length=100)
+
+
+class ChildModel(ParentModel):
+ name2 = models.CharField(max_length=100)
+
+
+class AssociatedModel(RESTFrameworkModel):
+ ref = models.OneToOneField(ParentModel, primary_key=True)
+ name = models.CharField(max_length=100)
+
+
+# Serializers
+class DerivedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChildModel
+
+
+class AssociatedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AssociatedModel
+
+
+# Tests
+class IneritedModelSerializationTests(TestCase):
+
+ def test_multitable_inherited_model_fields_as_expected(self):
+ """
+ Assert that the parent pointer field is not included in the fields
+ serialized fields
+ """
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name1', 'name2', 'id']))
+
+ def test_onetoone_primary_key_model_fields_as_expected(self):
+ """
+ Assert that a model with a onetoone field that is the primary key is
+ not treated like a derived model
+ """
+ parent = ParentModel(name1='parent name')
+ associate = AssociatedModel(name='hello', ref=parent)
+ serializer = AssociatedModelSerializer(associate)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'ref']))
+
+ def test_data_is_valid_without_parent_ptr(self):
+ """
+ Assert that the pointer to the parent table is not a required field
+ for input data
+ """
+ data = {
+ 'name1': 'parent name',
+ 'name2': 'child name',
+ }
+ serializer = DerivedModelSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py
index e06354ea..43721b84 100644
--- a/rest_framework/tests/negotiation.py
+++ b/rest_framework/tests/negotiation.py
@@ -1,6 +1,9 @@
+from __future__ import unicode_literals
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.negotiation import DefaultContentNegotiation
+from rest_framework.request import Request
+
factory = RequestFactory()
@@ -22,16 +25,16 @@ class TestAcceptedMediaType(TestCase):
return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self):
- request = factory.get('/')
+ request = Request(factory.get('/'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEquals(accepted_media_type, 'application/json')
+ self.assertEqual(accepted_media_type, 'application/json')
def test_client_underspecifies_accept_use_renderer(self):
- request = factory.get('/', HTTP_ACCEPT='*/*')
+ request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEquals(accepted_media_type, 'application/json')
+ self.assertEqual(accepted_media_type, 'application/json')
def test_client_overspecifies_accept_use_client(self):
- request = factory.get('/', HTTP_ACCEPT='application/json; indent=8')
+ request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEquals(accepted_media_type, 'application/json; indent=8')
+ self.assertEqual(accepted_media_type, 'application/json; indent=8')
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 3b550877..1a2d68a6 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,5 +1,7 @@
+from __future__ import unicode_literals
import datetime
from decimal import Decimal
+import django
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
@@ -19,21 +21,6 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
-if django_filters:
- class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_class = DecimalFilter
- filter_backend = filters.DjangoFilterBackend
-
-
class DefaultPageSizeKwargView(generics.ListAPIView):
"""
View for testing default paginate_by_param usage
@@ -72,28 +59,32 @@ class IntegrationTestPagination(TestCase):
GET requests to paginated ListCreateAPIView should return paginated results.
"""
request = factory.get('/')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 26)
- self.assertEquals(response.data['results'], self.data[:10])
- self.assertNotEquals(response.data['next'], None)
- self.assertEquals(response.data['previous'], None)
+ # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 26)
- self.assertEquals(response.data['results'], self.data[10:20])
- self.assertNotEquals(response.data['next'], None)
- self.assertNotEquals(response.data['previous'], None)
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[10:20])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 26)
- self.assertEquals(response.data['results'], self.data[20:])
- self.assertEquals(response.data['next'], None)
- self.assertNotEquals(response.data['previous'], None)
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[20:])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
class IntegrationTestPaginationAndFiltering(TestCase):
@@ -111,41 +102,115 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ for obj in self.objects.all()
]
- self.view = FilterFieldsRootView.as_view()
@unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_paginated_filtered_root_view(self):
+ def test_get_django_filter_paginated_filtered_root_view(self):
"""
GET requests to paginated filtered ListCreateAPIView should return
paginated results. The next and previous links should preserve the
filtered parameters.
"""
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backend = filters.DjangoFilterBackend
+
+ view = FilterFieldsRootView.as_view()
+
+ EXPECTED_NUM_QUERIES = 2
+ if django.VERSION < (1, 4):
+ # On Django 1.3 we need to use django-filter 0.5.4
+ #
+ # The filter objects there don't expose a `.count()` method,
+ # which means we only make a single query *but* it's a single
+ # query across *all* of the queryset, instead of a COUNT and then
+ # a SELECT with a LIMIT.
+ #
+ # Although this is fewer queries, it's actually a regression.
+ EXPECTED_NUM_QUERIES = 1
+
request = factory.get('/?decimal=15.20')
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 15)
- self.assertEquals(response.data['results'], self.data[:10])
- self.assertNotEquals(response.data['next'], None)
- self.assertEquals(response.data['previous'], None)
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next'])
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 15)
- self.assertEquals(response.data['results'], self.data[10:15])
- self.assertEquals(response.data['next'], None)
- self.assertNotEquals(response.data['previous'], None)
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous'])
- response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
- self.assertEquals(response.data['count'], 15)
- self.assertEquals(response.data['results'], self.data[:10])
- self.assertNotEquals(response.data['next'], None)
- self.assertEquals(response.data['previous'], None)
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ def test_get_basic_paginated_filtered_root_view(self):
+ """
+ Same as `test_get_django_filter_paginated_filtered_root_view`,
+ except using a custom filter backend instead of the django-filter
+ backend,
+ """
+
+ class DecimalFilterBackend(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
+
+ class BasicFilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_backend = DecimalFilterBackend
+
+ view = BasicFilterFieldsRootView.as_view()
+
+ request = factory.get('/?decimal=15.20')
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
@@ -166,16 +231,16 @@ class UnitTestPagination(TestCase):
def test_native_pagination(self):
serializer = pagination.PaginationSerializer(self.first_page)
- self.assertEquals(serializer.data['count'], 26)
- self.assertEquals(serializer.data['next'], '?page=2')
- self.assertEquals(serializer.data['previous'], None)
- self.assertEquals(serializer.data['results'], self.objects[:10])
+ self.assertEqual(serializer.data['count'], 26)
+ self.assertEqual(serializer.data['next'], '?page=2')
+ self.assertEqual(serializer.data['previous'], None)
+ self.assertEqual(serializer.data['results'], self.objects[:10])
serializer = pagination.PaginationSerializer(self.last_page)
- self.assertEquals(serializer.data['count'], 26)
- self.assertEquals(serializer.data['next'], None)
- self.assertEquals(serializer.data['previous'], '?page=2')
- self.assertEquals(serializer.data['results'], self.objects[20:])
+ self.assertEqual(serializer.data['count'], 26)
+ self.assertEqual(serializer.data['next'], None)
+ self.assertEqual(serializer.data['previous'], '?page=2')
+ self.assertEqual(serializer.data['results'], self.objects[20:])
def test_context_available_in_result(self):
"""
@@ -184,7 +249,7 @@ class UnitTestPagination(TestCase):
serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
serializer.data
results = serializer.fields[serializer.results_field]
- self.assertEquals(serializer.context, results.context)
+ self.assertEqual(serializer.context, results.context)
class TestUnpaginated(TestCase):
@@ -212,7 +277,7 @@ class TestUnpaginated(TestCase):
"""
request = factory.get('/')
response = self.view(request)
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.data, self.data)
class TestCustomPaginateByParam(TestCase):
@@ -240,7 +305,7 @@ class TestCustomPaginateByParam(TestCase):
"""
request = factory.get('/')
response = self.view(request).render()
- self.assertEquals(response.data, self.data)
+ self.assertEqual(response.data, self.data)
def test_paginate_by_param(self):
"""
@@ -248,9 +313,11 @@ class TestCustomPaginateByParam(TestCase):
"""
request = factory.get('/?page_size=5')
response = self.view(request).render()
- self.assertEquals(response.data['count'], 13)
- self.assertEquals(response.data['results'], self.data[:5])
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+### Tests for context in pagination serializers
class CustomField(serializers.Field):
def to_native(self, value):
@@ -262,6 +329,11 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer):
text = CustomField()
+ def __init__(self, *args, **kwargs):
+ super(BasicModelSerializer, self).__init__(*args, **kwargs)
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into serializer init")
+
class TestContextPassedToCustomField(TestCase):
def setUp(self):
@@ -277,5 +349,41 @@ class TestContextPassedToCustomField(TestCase):
request = factory.get('/')
response = self.view(request).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+### Tests for custom pagination serializers
+
+class LinksSerializer(serializers.Serializer):
+ next = pagination.NextPageField(source='*')
+ prev = pagination.PreviousPageField(source='*')
+
+class CustomPaginationSerializer(pagination.BasePaginationSerializer):
+ links = LinksSerializer(source='*') # Takes the page object as the source
+ total_results = serializers.Field(source='paginator.count')
+
+ results_field = 'objects'
+
+
+class TestCustomPaginationSerializer(TestCase):
+ def setUp(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = Paginator(objects, 2)
+ self.page = paginator.page(1)
+
+ def test_custom_pagination_serializer(self):
+ request = RequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=self.page,
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page=2',
+ 'prev': None
+ },
+ 'total_results': 4,
+ 'objects': ['john', 'paul']
+ }
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py
index 8ab8a52f..539c5b44 100644
--- a/rest_framework/tests/parsers.py
+++ b/rest_framework/tests/parsers.py
@@ -1,139 +1,9 @@
-# """
-# ..
-# >>> from rest_framework.parsers import FormParser
-# >>> from django.test.client import RequestFactory
-# >>> from rest_framework.views import View
-# >>> from StringIO import StringIO
-# >>> from urllib import urlencode
-# >>> req = RequestFactory().get('/')
-# >>> some_view = View()
-# >>> some_view.request = req # Make as if this request had been dispatched
-#
-# FormParser
-# ============
-#
-# Data flatening
-# ----------------
-#
-# Here is some example data, which would eventually be sent along with a post request :
-#
-# >>> inpt = urlencode([
-# ... ('key1', 'bla1'),
-# ... ('key2', 'blo1'), ('key2', 'blo2'),
-# ... ])
-#
-# Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter :
-#
-# >>> (data, files) = FormParser(some_view).parse(StringIO(inpt))
-# >>> data == {'key1': 'bla1', 'key2': 'blo1'}
-# True
-#
-# However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` :
-#
-# >>> class MyFormParser(FormParser):
-# ...
-# ... def is_a_list(self, key, val_list):
-# ... return len(val_list) > 1
-#
-# This new parser only flattens the lists of parameters that contain a single value.
-#
-# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
-# >>> data == {'key1': 'bla1', 'key2': ['blo1', 'blo2']}
-# True
-#
-# .. note:: The same functionality is available for :class:`parsers.MultiPartParser`.
-#
-# Submitting an empty list
-# --------------------------
-#
-# When submitting an empty select multiple, like this one ::
-#
-# <select multiple="multiple" name="key2"></select>
-#
-# The browsers usually strip the parameter completely. A hack to avoid this, and therefore being able to submit an empty select multiple, is to submit a value that tells the server that the list is empty ::
-#
-# <select multiple="multiple" name="key2"><option value="_empty"></select>
-#
-# :class:`parsers.FormParser` provides the server-side implementation for this hack. Considering the following posted data :
-#
-# >>> inpt = urlencode([
-# ... ('key1', 'blo1'), ('key1', '_empty'),
-# ... ('key2', '_empty'),
-# ... ])
-#
-# :class:`parsers.FormParser` strips the values ``_empty`` from all the lists.
-#
-# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
-# >>> data == {'key1': 'blo1'}
-# True
-#
-# Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it.
-#
-# >>> class MyFormParser(FormParser):
-# ...
-# ... def is_a_list(self, key, val_list):
-# ... return key == 'key2'
-# ...
-# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
-# >>> data == {'key1': 'blo1', 'key2': []}
-# True
-#
-# Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`.
-# """
-# import httplib, mimetypes
-# from tempfile import TemporaryFile
-# from django.test import TestCase
-# from django.test.client import RequestFactory
-# from rest_framework.parsers import MultiPartParser
-# from rest_framework.views import View
-# from StringIO import StringIO
-#
-# def encode_multipart_formdata(fields, files):
-# """For testing multipart parser.
-# fields is a sequence of (name, value) elements for regular form fields.
-# files is a sequence of (name, filename, value) elements for data to be uploaded as files
-# Return (content_type, body)."""
-# BOUNDARY = '----------ThIs_Is_tHe_bouNdaRY_$'
-# CRLF = '\r\n'
-# L = []
-# for (key, value) in fields:
-# L.append('--' + BOUNDARY)
-# L.append('Content-Disposition: form-data; name="%s"' % key)
-# L.append('')
-# L.append(value)
-# for (key, filename, value) in files:
-# L.append('--' + BOUNDARY)
-# L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename))
-# L.append('Content-Type: %s' % get_content_type(filename))
-# L.append('')
-# L.append(value)
-# L.append('--' + BOUNDARY + '--')
-# L.append('')
-# body = CRLF.join(L)
-# content_type = 'multipart/form-data; boundary=%s' % BOUNDARY
-# return content_type, body
-#
-# def get_content_type(filename):
-# return mimetypes.guess_type(filename)[0] or 'application/octet-stream'
-#
-#class TestMultiPartParser(TestCase):
-# def setUp(self):
-# self.req = RequestFactory()
-# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')],
-# [('file1', 'pic.jpg', 'blablabla'), ('file1', 't.txt', 'blobloblo')])
-#
-# def test_multipartparser(self):
-# """Ensure that MultiPartParser can parse multipart/form-data that contains a mix of several files and parameters."""
-# post_req = RequestFactory().post('/', self.body, content_type=self.content_type)
-# view = View()
-# view.request = post_req
-# (data, files) = MultiPartParser(view).parse(StringIO(self.body))
-# self.assertEqual(data['key1'], 'val1')
-# self.assertEqual(files['file1'].read(), 'blablabla')
-
-from StringIO import StringIO
+from __future__ import unicode_literals
+from rest_framework.compat import StringIO
from django import forms
from django.test import TestCase
+from django.utils import unittest
+from rest_framework.compat import etree
from rest_framework.parsers import FormParser
from rest_framework.parsers import XMLParser
import datetime
@@ -201,11 +71,13 @@ class TestXMLParser(TestCase):
]
}
+ @unittest.skipUnless(etree, 'defusedxml not installed')
def test_parse(self):
parser = XMLParser()
data = parser.parse(self._input)
self.assertEqual(data, self._data)
+ @unittest.skipUnless(etree, 'defusedxml not installed')
def test_complex_data_parse(self):
parser = XMLParser()
data = parser.parse(self._complex_data_input)
diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/permissions.py
new file mode 100644
index 00000000..b3993be5
--- /dev/null
+++ b/rest_framework/tests/permissions.py
@@ -0,0 +1,153 @@
+from __future__ import unicode_literals
+from django.contrib.auth.models import User, Permission
+from django.db import models
+from django.test import TestCase
+from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
+from rest_framework.tests.utils import RequestFactory
+import base64
+import json
+
+factory = RequestFactory()
+
+
+class BasicModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class RootView(generics.ListCreateAPIView):
+ model = BasicModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+root_view = RootView.as_view()
+instance_view = InstanceView.as_view()
+
+
+def basic_auth_header(username, password):
+ credentials = ('%s:%s' % (username, password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ return 'Basic %s' % base64_credentials
+
+
+class ModelPermissionsIntegrationTests(TestCase):
+ def setUp(self):
+ User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
+ user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='add_basicmodel'),
+ Permission.objects.get(codename='change_basicmodel'),
+ Permission.objects.get(codename='delete_basicmodel')
+ ]
+ user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='change_basicmodel'),
+ ]
+
+ self.permitted_credentials = basic_auth_header('permitted', 'password')
+ self.disallowed_credentials = basic_auth_header('disallowed', 'password')
+ self.updateonly_credentials = basic_auth_header('updateonly', 'password')
+
+ BasicModel(text='foo').save()
+
+ def test_has_create_permissions(self):
+ request = factory.post('/', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+ def test_has_put_permissions(self):
+ request = factory.put('/1', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_has_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_does_not_have_create_permissions(self):
+ request = factory.post('/', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_put_permissions(self):
+ request = factory.put('/1', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_has_put_as_create_permissions(self):
+ # User only has update permissions - should be able to update an entity.
+ request = factory.put('/1', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ # But if PUTing to a new entity, permission should be denied.
+ request = factory.put('/2', json.dumps({'text': 'foobar'}),
+ content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='2')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+
+class OwnerModel(models.Model):
+ text = models.CharField(max_length=100)
+ owner = models.ForeignKey(User)
+
+
+class IsOwnerPermission(permissions.BasePermission):
+ def has_object_permission(self, request, view, obj):
+ return request.user == obj.owner
+
+
+class OwnerInstanceView(generics.RetrieveUpdateDestroyAPIView):
+ model = OwnerModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [IsOwnerPermission]
+
+
+owner_instance_view = OwnerInstanceView.as_view()
+
+
+class ObjectPermissionsIntegrationTests(TestCase):
+ """
+ Integration tests for the object level permissions API.
+ """
+
+ def setUp(self):
+ User.objects.create_user('not_owner', 'not_owner@example.com', 'password')
+ user = User.objects.create_user('owner', 'owner@example.com', 'password')
+
+ self.not_owner_credentials = basic_auth_header('not_owner', 'password')
+ self.owner_credentials = basic_auth_header('owner', 'password')
+
+ OwnerModel(text='foo', owner=user).save()
+
+ def test_owner_has_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.owner_credentials)
+ response = owner_instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_non_owner_does_not_have_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.not_owner_credentials)
+ response = owner_instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
index 91daea8a..cbf93c65 100644
--- a/rest_framework/tests/relations.py
+++ b/rest_framework/tests/relations.py
@@ -1,7 +1,7 @@
"""
General tests for relational fields.
"""
-
+from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import serializers
@@ -31,3 +31,17 @@ class FieldTests(TestCase):
field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+
+class TestManyRelateMixin(TestCase):
+ def test_missing_many_to_many_related_field(self):
+ '''
+ Regression test for #632
+
+ https://github.com/tomchristie/django-rest-framework/pull/632
+ '''
+ field = serializers.RelatedField(many=True, read_only=False)
+
+ into = {}
+ field.field_from_native({}, None, 'field_name', into)
+ self.assertEqual(into['field_name'], [])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index 57913670..b5702a48 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -1,7 +1,16 @@
+from __future__ import unicode_literals
from django.test import TestCase
+from django.test.client import RequestFactory
from rest_framework import serializers
from rest_framework.compat import patterns, url
-from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+from rest_framework.tests.models import (
+ ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+)
+
+factory = RequestFactory()
+request = factory.get('/') # Just to ensure we have a request in the serializer context
+
def dummy_view(request, pk):
pass
@@ -16,8 +25,9 @@ urlpatterns = patterns('',
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
)
+
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail')
+ sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
class Meta:
model = ManyToManyTarget
@@ -29,7 +39,7 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail')
+ sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
class Meta:
model = ForeignKeyTarget
@@ -70,98 +80,98 @@ class HyperlinkedManyToManyTests(TestCase):
def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']},
- {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
- {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self):
- data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data)
+ serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
- {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
- {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_update(self):
- data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}
+ data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data)
+ serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']},
- {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_many_to_many_create(self):
- data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
- serializer = ManyToManySourceSerializer(data=data)
+ data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ serializer = ManyToManySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']},
- {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
- {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']},
- {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']}
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_create(self):
- data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
- serializer = ManyToManyTargetSerializer(data=data)
+ data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'target-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
- {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']},
- {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
class HyperlinkedForeignKeyTests(TestCase):
@@ -178,111 +188,118 @@ class HyperlinkedForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
- {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self):
- data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'},
- {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']})
def test_reverse_foreign_key_update(self):
- data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
+ data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data)
+ serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset)
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
- {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
- ]
- self.assertEquals(new_serializer.data, expected)
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
- {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self):
- data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}
- serializer = ForeignKeySourceSerializer(data=data)
+ data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'},
- {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'},
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self):
- data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
- serializer = ForeignKeyTargetSerializer(data=data)
+ data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'target-3')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
# Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']},
- {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []},
- {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self):
- data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None}
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
class HyperlinkedNullableForeignKeyTests(TestCase):
@@ -299,118 +316,118 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self):
- data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
- {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''}
- expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, expected_data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
- {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None}
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self):
- data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''}
- expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(serializer.data, expected_data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None},
- {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'},
- {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
# reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self):
- # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid())
- # self.assertEquals(serializer.data, data)
+ # self.assertEqual(serializer.data, data)
# serializer.save()
# # Ensure target 1 is updated, and everything else is as expected
# queryset = ForeignKeyTarget.objects.all()
- # serializer = ForeignKeyTargetSerializer(queryset)
+ # serializer = ForeignKeyTargetSerializer(queryset, many=True)
# expected = [
- # {'id': 1, 'name': u'target-1', 'sources': [1]},
- # {'id': 2, 'name': u'target-2', 'sources': []},
+ # {'id': 1, 'name': 'target-1', 'sources': [1]},
+ # {'id': 2, 'name': 'target-2', 'sources': []},
# ]
- # self.assertEquals(serializer.data, expected)
+ # self.assertEqual(serializer.data, expected)
class HyperlinkedNullableOneToOneTests(TestCase):
@@ -426,9 +443,9 @@ class HyperlinkedNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset)
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
expected = [
- {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'},
- {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None},
+ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
+ {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index 0e129fae..a125ba65 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
@@ -15,7 +16,7 @@ class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = FlatForeignKeySourceSerializer()
+ sources = FlatForeignKeySourceSerializer(many=True)
class Meta:
model = ForeignKeyTarget
@@ -51,27 +52,27 @@ class ReverseForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
- {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
- {'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}},
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1},
+ {'id': 1, 'name': 'target-1', 'sources': [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1},
]},
- {'id': 2, 'name': u'target-2', 'sources': [
+ {'id': 2, 'name': 'target-2', 'sources': [
]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
class NestedNullableForeignKeyTests(TestCase):
@@ -86,13 +87,13 @@ class NestedNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}},
- {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}},
- {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 3, 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
class NestedNullableOneToOneTests(TestCase):
@@ -106,9 +107,9 @@ class NestedNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset)
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}},
- {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}},
+ {'id': 2, 'name': 'target-2', 'nullable_source': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index 54835860..f08e1808 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -1,10 +1,12 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+from rest_framework.compat import six
class ManyToManyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.ManyPrimaryKeyRelatedField()
+ sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta:
model = ManyToManyTarget
@@ -16,7 +18,7 @@ class ManyToManySourceSerializer(serializers.ModelSerializer):
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.ManyPrimaryKeyRelatedField()
+ sources = serializers.PrimaryKeyRelatedField(many=True)
class Meta:
model = ForeignKeyTarget
@@ -54,97 +56,97 @@ class PKManyToManyTests(TestCase):
def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'targets': [1]},
- {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': u'target-3', 'sources': [3]}
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self):
- data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}
+ data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]},
- {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_update(self):
- data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ data = {'id': 1, 'name': 'target-1', 'sources': [1]}
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1]},
- {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': u'target-3', 'sources': [3]}
+ {'id': 1, 'name': 'target-1', 'sources': [1]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_many_to_many_create(self):
- data = {'id': 4, 'name': u'source-4', 'targets': [1, 3]}
+ data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
serializer = ManyToManySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset)
+ serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'targets': [1]},
- {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]},
- {'id': 4, 'name': u'source-4', 'targets': [1, 3]},
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
+ {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_create(self):
- data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'target-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset)
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': u'target-3', 'sources': [3]},
- {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]},
+ {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
class PKForeignKeyTests(TestCase):
@@ -159,111 +161,118 @@ class PKForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1}
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': u'target-2', 'sources': []},
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self):
- data = {'id': 1, 'name': u'source-1', 'target': 2}
+ data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 2},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1}
+ {'id': 1, 'name': 'source-1', 'target': 2},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]}
+ data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset)
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': u'target-2', 'sources': []},
- ]
- self.assertEquals(new_serializer.data, expected)
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
serializer.save()
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [2]},
- {'id': 2, 'name': u'target-2', 'sources': [1, 3]},
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self):
- data = {'id': 4, 'name': u'source-4', 'target': 2}
+ data = {'id': 4, 'name': 'source-4', 'target': 2}
serializer = ForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset)
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': 1},
- {'id': 4, 'name': u'source-4', 'target': 2},
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1},
+ {'id': 4, 'name': 'source-4', 'target': 2},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]}
+ data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'target-3')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset)
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'sources': [2]},
- {'id': 2, 'name': u'target-2', 'sources': []},
- {'id': 3, 'name': u'target-3', 'sources': [1, 3]},
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': u'source-1', 'target': None}
+ data = {'id': 1, 'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
class PKNullableForeignKeyTests(TestCase):
@@ -278,118 +287,118 @@ class PKNullableForeignKeyTests(TestCase):
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': u'source-4', 'target': None}
+ data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None},
- {'id': 4, 'name': u'source-4', 'target': None}
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 4, 'name': u'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': u'source-4', 'target': None}
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
- self.assertEquals(serializer.data, expected_data)
- self.assertEqual(obj.name, u'source-4')
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': 1},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None},
- {'id': 4, 'name': u'source-4', 'target': None}
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': u'source-1', 'target': None}
+ data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, data)
+ self.assertEqual(serializer.data, data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': None},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None}
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_emptystring(self):
"""
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 1, 'name': u'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': u'source-1', 'target': None}
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(serializer.data, expected_data)
serializer.save()
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset)
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'source-1', 'target': None},
- {'id': 2, 'name': u'source-2', 'target': 1},
- {'id': 3, 'name': u'source-3', 'target': None}
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
# reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self):
- # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid())
- # self.assertEquals(serializer.data, data)
+ # self.assertEqual(serializer.data, data)
# serializer.save()
# # Ensure target 1 is updated, and everything else is as expected
# queryset = ForeignKeyTarget.objects.all()
- # serializer = ForeignKeyTargetSerializer(queryset)
+ # serializer = ForeignKeyTargetSerializer(queryset, many=True)
# expected = [
- # {'id': 1, 'name': u'target-1', 'sources': [1]},
- # {'id': 2, 'name': u'target-2', 'sources': []},
+ # {'id': 1, 'name': 'target-1', 'sources': [1]},
+ # {'id': 2, 'name': 'target-2', 'sources': []},
# ]
- # self.assertEquals(serializer.data, expected)
+ # self.assertEqual(serializer.data, expected)
class PKNullableOneToOneTests(TestCase):
@@ -398,14 +407,14 @@ class PKNullableOneToOneTests(TestCase):
target.save()
new_target = OneToOneTarget(name='target-2')
new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
+ source = NullableOneToOneSource(name='source-1', target=new_target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset)
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': u'target-1', 'nullable_source': 1},
- {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'nullable_source': None},
+ {'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py
new file mode 100644
index 00000000..435c821c
--- /dev/null
+++ b/rest_framework/tests/relations_slug.py
@@ -0,0 +1,257 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.SlugRelatedField(many=True, slug_field='name')
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name')
+
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name', required=False)
+
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# TODO: M2M Tests, FKTests (Non-nullable), One2One
+class SlugForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-2'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 123}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ serializer.is_valid()
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'},
+ {'id': 4, 'name': 'source-4', 'target': 'target-2'},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+
+
+class SlugNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py
index c1b4e624..40bac9cb 100644
--- a/rest_framework/tests/renderers.py
+++ b/rest_framework/tests/renderers.py
@@ -1,29 +1,28 @@
-import pickle
-import re
-
+from decimal import Decimal
from django.core.cache import cache
from django.test import TestCase
from django.test.client import RequestFactory
-
+from django.utils import unittest
from rest_framework import status, permissions
-from rest_framework.compat import yaml, patterns, url, include
+from rest_framework.compat import yaml, etree, patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
XMLRenderer, JSONPRenderer, BrowsableAPIRenderer
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
-
-from StringIO import StringIO
+from rest_framework.compat import StringIO
+from rest_framework.compat import six
import datetime
-from decimal import Decimal
+import pickle
+import re
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
-RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
-RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
+RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
+RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
expected_results = [
@@ -35,7 +34,7 @@ class BasicRendererTests(TestCase):
def test_expected_results(self):
for value, renderer_cls, expected in expected_results:
output = renderer_cls().render(value)
- self.assertEquals(output, expected)
+ self.assertEqual(output, expected)
class RendererA(BaseRenderer):
@@ -94,7 +93,7 @@ urlpatterns = patterns('',
class POSTDeniedPermission(permissions.BasePermission):
- def has_permission(self, request, view, obj=None):
+ def has_permission(self, request, view):
return request.method != 'POST'
@@ -111,6 +110,9 @@ class POSTDeniedView(APIView):
def put(self, request):
return Response()
+ def patch(self, request):
+ return Response()
+
class DocumentingRendererTests(TestCase):
def test_only_permitted_forms_are_displayed(self):
@@ -119,6 +121,7 @@ class DocumentingRendererTests(TestCase):
response = view(request).render()
self.assertNotContains(response, '>POST<')
self.assertContains(response, '>PUT<')
+ self.assertContains(response, '>PATCH<')
class RendererEndToEndTests(TestCase):
@@ -131,39 +134,39 @@ class RendererEndToEndTests(TestCase):
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
- self.assertEquals(resp.status_code, DUMMYSTATUS)
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, '')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_accept_query(self):
"""The '_accept' query string should behave in the same way as the Accept header."""
@@ -172,14 +175,14 @@ class RendererEndToEndTests(TestCase):
RendererB.media_type
)
resp = self.client.get('/' + param)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+ self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
@@ -189,17 +192,17 @@ class RendererEndToEndTests(TestCase):
RendererB.format
)
resp = self.client.get('/' + param)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
@@ -210,9 +213,9 @@ class RendererEndToEndTests(TestCase):
)
resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
_flat_repr = '{"foo": ["bar", "baz"]}'
@@ -240,7 +243,7 @@ class JSONRendererTests(TestCase):
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library.
- self.assertEquals(content, _flat_repr)
+ self.assertEqual(content, _flat_repr)
def test_with_content_type_args(self):
"""
@@ -249,7 +252,7 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2')
- self.assertEquals(strip_trailing_whitespace(content), _indented_repr)
+ self.assertEqual(strip_trailing_whitespace(content), _indented_repr)
class JSONPRendererTests(TestCase):
@@ -265,9 +268,10 @@ class JSONPRendererTests(TestCase):
"""
resp = self.client.get('/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript')
- self.assertEquals(resp.status_code, 200)
- self.assertEquals(resp['Content-Type'], 'application/javascript')
- self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp.content,
+ ('callback(%s);' % _flat_repr).encode('ascii'))
def test_without_callback_without_json_renderer(self):
"""
@@ -275,9 +279,10 @@ class JSONPRendererTests(TestCase):
"""
resp = self.client.get('/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript')
- self.assertEquals(resp.status_code, 200)
- self.assertEquals(resp['Content-Type'], 'application/javascript')
- self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp.content,
+ ('callback(%s);' % _flat_repr).encode('ascii'))
def test_with_callback(self):
"""
@@ -286,9 +291,10 @@ class JSONPRendererTests(TestCase):
callback_func = 'myjsonpcallback'
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript')
- self.assertEquals(resp.status_code, 200)
- self.assertEquals(resp['Content-Type'], 'application/javascript')
- self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr))
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp.content,
+ ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
if yaml:
@@ -306,7 +312,7 @@ if yaml:
obj = {'foo': ['bar', 'baz']}
renderer = YAMLRenderer()
content = renderer.render(obj, 'application/yaml')
- self.assertEquals(content, _yaml_repr)
+ self.assertEqual(content, _yaml_repr)
def test_render_and_parse(self):
"""
@@ -320,7 +326,7 @@ if yaml:
content = renderer.render(obj, 'application/yaml')
data = parser.parse(StringIO(content))
- self.assertEquals(obj, data)
+ self.assertEqual(obj, data)
class XMLRendererTestCase(TestCase):
@@ -402,6 +408,7 @@ class XMLRendererTestCase(TestCase):
self.assertXMLContains(content, '<sub_name>first</sub_name>')
self.assertXMLContains(content, '<sub_name>second</sub_name>')
+ @unittest.skipUnless(etree, 'defusedxml not installed')
def test_render_and_parse_complex_data(self):
"""
Test XML rendering.
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index 4b032405..97e5af20 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -1,7 +1,7 @@
"""
Tests for content parsing, and form-overloaded content parsing.
"""
-import json
+from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware
@@ -20,6 +20,8 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
+from rest_framework.compat import six
+import json
factory = RequestFactory()
@@ -56,21 +58,29 @@ class TestMethodOverloading(TestCase):
request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
self.assertEqual(request.method, 'DELETE')
+ def test_x_http_method_override_header(self):
+ """
+ POST requests can also be overloaded to another method by setting
+ the X-HTTP-Method-Override header.
+ """
+ request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
class TestContentParsing(TestCase):
def test_standard_behaviour_determines_no_content_GET(self):
"""
- Ensure request.DATA returns None for GET request with no content.
+ Ensure request.DATA returns empty QueryDict for GET request.
"""
request = Request(factory.get('/'))
- self.assertEqual(request.DATA, None)
+ self.assertEqual(request.DATA, {})
def test_standard_behaviour_determines_no_content_HEAD(self):
"""
- Ensure request.DATA returns None for HEAD request.
+ Ensure request.DATA returns empty QueryDict for HEAD request.
"""
request = Request(factory.head('/'))
- self.assertEqual(request.DATA, None)
+ self.assertEqual(request.DATA, {})
def test_request_DATA_with_form_content(self):
"""
@@ -79,14 +89,14 @@ class TestContentParsing(TestCase):
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(request.DATA.items(), data.items())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
def test_request_DATA_with_text_content(self):
"""
Ensure request.DATA returns content for POST request with
non-form content.
"""
- content = 'qwerty'
+ content = six.b('qwerty')
content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type))
request.parsers = (PlainTextParser(),)
@@ -99,7 +109,7 @@ class TestContentParsing(TestCase):
data = {'qwerty': 'uiop'}
request = Request(factory.post('/', data))
request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(request.POST.items(), data.items())
+ self.assertEqual(list(request.POST.items()), list(data.items()))
def test_standard_behaviour_determines_form_content_PUT(self):
"""
@@ -117,14 +127,14 @@ class TestContentParsing(TestCase):
request = Request(factory.put('/', data))
request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(request.DATA.items(), data.items())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
def test_standard_behaviour_determines_non_form_content_PUT(self):
"""
Ensure request.DATA returns content for PUT request with
non-form content.
"""
- content = 'qwerty'
+ content = six.b('qwerty')
content_type = 'text/plain'
request = Request(factory.put('/', content, content_type=content_type))
request.parsers = (PlainTextParser(), )
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
index 875f4d42..aecf83f4 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/response.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.compat import patterns, url, include
from rest_framework.response import Response
@@ -9,6 +10,7 @@ from rest_framework.renderers import (
BrowsableAPIRenderer
)
from rest_framework.settings import api_settings
+from rest_framework.compat import six
class MockPickleRenderer(BaseRenderer):
@@ -22,8 +24,8 @@ class MockJsonRenderer(BaseRenderer):
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
-RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
-RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
+RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
+RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
class RendererA(BaseRenderer):
@@ -83,39 +85,39 @@ class RendererIntegrationTests(TestCase):
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
- self.assertEquals(resp.status_code, DUMMYSTATUS)
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, '')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEquals(resp['Content-Type'], RendererA.media_type)
- self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_accept_query(self):
"""The '_accept' query string should behave in the same way as the Accept header."""
@@ -124,34 +126,34 @@ class RendererIntegrationTests(TestCase):
RendererB.media_type
)
resp = self.client.get('/' + param)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type)
- self.assertEquals(resp['Content-Type'], RendererB.media_type)
- self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
class Issue122Tests(TestCase):
diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py
index 8c86e1fb..cb8d8132 100644
--- a/rest_framework/tests/reverse.py
+++ b/rest_framework/tests/reverse.py
@@ -1,3 +1,4 @@
+from __future__ import unicode_literals
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.compat import patterns, url
@@ -16,7 +17,7 @@ urlpatterns = patterns('',
class ReverseTests(TestCase):
"""
- Tests for fully qualifed URLs when using `reverse`.
+ Tests for fully qualified URLs when using `reverse`.
"""
urls = 'rest_framework.tests.reverse'
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index bd96ba23..beb372c2 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -1,10 +1,12 @@
-import datetime
-import pickle
+from __future__ import unicode_literals
+from django.utils.datastructures import MultiValueDict
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
+import datetime
+import pickle
class SubComment(object):
@@ -54,6 +56,19 @@ class ActionItemSerializer(serializers.ModelSerializer):
model = ActionItem
+class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
+
+ class Meta:
+ model = ActionItem
+
+ def restore_object(self, data, instance=None):
+ if instance is None:
+ return ActionItem(**data)
+ for key, val in data.items():
+ setattr(instance, key, val)
+ return instance
+
+
class PersonSerializer(serializers.ModelSerializer):
info = serializers.Field(source='info')
@@ -76,6 +91,11 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
fields = ['some_integer']
+class BrokenModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ fields = ['some_field']
+
+
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -92,7 +112,7 @@ class BasicTests(TestCase):
self.expected = {
'email': 'tom@example.com',
'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1),
+ 'created': '2012-01-01T00:00:00',
'sub_comment': 'And Merry Christmas!'
}
self.person_data = {'name': 'dwight', 'age': 35}
@@ -107,39 +127,39 @@ class BasicTests(TestCase):
'created': None,
'sub_comment': ''
}
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_retrieve(self):
serializer = CommentSerializer(self.comment)
- self.assertEquals(serializer.data, self.expected)
+ self.assertEqual(serializer.data, self.expected)
def test_create(self):
serializer = CommentSerializer(data=self.data)
expected = self.comment
- self.assertEquals(serializer.is_valid(), True)
- self.assertEquals(serializer.object, expected)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
self.assertFalse(serializer.object is expected)
- self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+ self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
- self.assertEquals(serializer.is_valid(), True)
- self.assertEquals(serializer.object, expected)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
self.assertTrue(serializer.object is expected)
- self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+ self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
def test_partial_update(self):
msg = 'Merry New Year!'
partial_data = {'content': msg}
serializer = CommentSerializer(self.comment, data=partial_data)
- self.assertEquals(serializer.is_valid(), False)
+ self.assertEqual(serializer.is_valid(), False)
serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
expected = self.comment
self.assertEqual(serializer.is_valid(), True)
- self.assertEquals(serializer.object, expected)
+ self.assertEqual(serializer.object, expected)
self.assertTrue(serializer.object is expected)
- self.assertEquals(serializer.data['content'], msg)
+ self.assertEqual(serializer.data['content'], msg)
def test_model_fields_as_expected(self):
"""
@@ -147,7 +167,7 @@ class BasicTests(TestCase):
in the Meta data
"""
serializer = PersonSerializer(self.person)
- self.assertEquals(set(serializer.data.keys()),
+ self.assertEqual(set(serializer.data.keys()),
set(['name', 'age', 'info']))
def test_field_with_dictionary(self):
@@ -156,19 +176,45 @@ class BasicTests(TestCase):
"""
serializer = PersonSerializer(self.person)
expected = self.person_data
- self.assertEquals(serializer.data['info'], expected)
+ self.assertEqual(serializer.data['info'], expected)
def test_read_only_fields(self):
"""
Attempting to update fields set as read_only should have no effect.
"""
-
serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(serializer.errors, {})
+ self.assertEqual(serializer.errors, {})
# Assert age is unchanged (35)
- self.assertEquals(instance.age, self.person_data['age'])
+ self.assertEqual(instance.age, self.person_data['age'])
+
+
+class DictStyleSerializer(serializers.Serializer):
+ """
+ Note that we don't have any `restore_object` method, so the default
+ case of simply returning a dict will apply.
+ """
+ email = serializers.EmailField()
+
+
+class DictStyleSerializerTests(TestCase):
+ def test_dict_style_deserialize(self):
+ """
+ Ensure serializers can deserialize into a dict.
+ """
+ data = {'email': 'foo@example.com'}
+ serializer = DictStyleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+
+ def test_dict_style_serialize(self):
+ """
+ Ensure serializers can serialize dict objects.
+ """
+ data = {'email': 'foo@example.com'}
+ serializer = DictStyleSerializer(data)
+ self.assertEqual(serializer.data, data)
class ValidationTests(TestCase):
@@ -183,18 +229,17 @@ class ValidationTests(TestCase):
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
}
- self.actionitem = ActionItem(title='Some to do item',
- )
+ self.actionitem = ActionItem(title='Some to do item',)
def test_create(self):
serializer = CommentSerializer(data=self.data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
def test_update_missing_field(self):
data = {
@@ -202,8 +247,8 @@ class ValidationTests(TestCase):
'created': datetime.datetime(2012, 1, 1)
}
serializer = CommentSerializer(self.comment, data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'email': [u'This field is required.']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'email': ['This field is required.']})
def test_missing_bool_with_default(self):
"""Make sure that a boolean value with a 'False' value is not
@@ -213,52 +258,36 @@ class ValidationTests(TestCase):
#No 'done' value.
}
serializer = ActionItemSerializer(self.actionitem, data=data)
- self.assertEquals(serializer.is_valid(), True)
- self.assertEquals(serializer.errors, {})
-
- def test_field_validation(self):
-
- class CommentSerializerWithFieldValidator(CommentSerializer):
-
- def validate_content(self, attrs, source):
- value = attrs[source]
- if "test" not in value:
- raise serializers.ValidationError("Test not in value")
- return attrs
-
- data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
-
- data['content'] = 'This should not validate'
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
def test_bad_type_data_is_false(self):
"""
Data of the wrong type is not valid.
"""
data = ['i am', 'a', 'list']
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+ serializer = CommentSerializer(self.comment, data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertTrue(isinstance(serializer.errors, list))
+
+ self.assertEqual(
+ serializer.errors,
+ [
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']}
+ ]
+ )
data = 'and i am a string'
serializer = CommentSerializer(self.comment, data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
data = 42
serializer = CommentSerializer(self.comment, data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
def test_cross_field_validation(self):
@@ -282,23 +311,37 @@ class ValidationTests(TestCase):
serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']})
+ self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']})
def test_null_is_true_fields(self):
"""
Omitting a value for null-field should validate.
"""
serializer = PersonSerializer(data={'name': 'marko'})
- self.assertEquals(serializer.is_valid(), True)
- self.assertEquals(serializer.errors, {})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
def test_modelserializer_max_length_exceeded(self):
data = {
'title': 'x' * 201,
}
serializer = ActionItemSerializer(data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_modelserializer_max_length_exceeded_with_custom_restore(self):
+ """
+ When overriding ModelSerializer.restore_object, validation tests should still apply.
+ Regression test for #623.
+
+ https://github.com/tomchristie/django-rest-framework/pull/623
+ """
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializerCustomRestore(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
def test_default_modelfield_max_length_exceeded(self):
data = {
@@ -306,15 +349,99 @@ class ValidationTests(TestCase):
'info': 'x' * 13,
}
serializer = ActionItemSerializer(data=data)
- self.assertEquals(serializer.is_valid(), False)
- self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
+
+ def test_datetime_validation_failure(self):
+ """
+ Test DateTimeField validation errors on non-str values.
+ Regression test for #669.
+
+ https://github.com/tomchristie/django-rest-framework/issues/669
+ """
+ data = self.data
+ data['created'] = 0
+
+ serializer = CommentSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+
+ self.assertIn('created', serializer.errors)
+
+ def test_missing_model_field_exception_msg(self):
+ """
+ Assert that a meaningful exception message is outputted when the model
+ field is missing (e.g. when mistyping ``model``).
+ """
+ try:
+ serializer = BrokenModelSerializer()
+ except AssertionError as e:
+ self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option")
+ except:
+ self.fail('Wrong exception type thrown.')
+
+
+class CustomValidationTests(TestCase):
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_email(self, attrs, source):
+ value = attrs[source]
+
+ return attrs
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ def test_field_validation(self):
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'content': ['Test not in value']})
+
+ def test_missing_data(self):
+ """
+ Make sure that validate_content isn't called if the field is missing
+ """
+ incomplete_data = {
+ 'email': 'tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'content': ['This field is required.']})
+
+ def test_wrong_data(self):
+ """
+ Make sure that validate_content isn't called if the field input is wrong
+ """
+ wrong_data = {
+ 'email': 'not an email',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'email': ['Enter a valid e-mail address.']})
class PositiveIntegerAsChoiceTests(TestCase):
def test_positive_integer_in_json_is_correctly_parsed(self):
- data = {'some_integer':1}
+ data = {'some_integer': 1}
serializer = PositiveIntegerAsChoiceSerializer(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
+
class ModelValidationTests(TestCase):
def test_validate_unique(self):
@@ -326,7 +453,7 @@ class ModelValidationTests(TestCase):
serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid())
- self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']})
+ self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
def test_foreign_key_with_partial(self):
"""
@@ -364,15 +491,15 @@ class RegexValidationTest(TestCase):
def test_create_failed(self):
serializer = BookSerializer(data={'isbn': '1234567890'})
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': '12345678901234'})
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']})
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
def test_create_success(self):
serializer = BookSerializer(data={'isbn': '1234567890123'})
@@ -417,7 +544,7 @@ class ManyToManyTests(TestCase):
"""
serializer = self.serializer_class(instance=self.instance)
expected = self.data
- self.assertEquals(serializer.data, expected)
+ self.assertEqual(serializer.data, expected)
def test_create(self):
"""
@@ -425,11 +552,11 @@ class ManyToManyTests(TestCase):
"""
data = {'rel': [self.anchor.id]}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 2)
- self.assertEquals(instance.pk, 2)
- self.assertEquals(list(instance.rel.all()), [self.anchor])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
def test_update(self):
"""
@@ -439,11 +566,11 @@ class ManyToManyTests(TestCase):
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
serializer = self.serializer_class(self.instance, data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor])
def test_create_empty_relationship(self):
"""
@@ -452,11 +579,11 @@ class ManyToManyTests(TestCase):
"""
data = {'rel': []}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 2)
- self.assertEquals(instance.pk, 2)
- self.assertEquals(list(instance.rel.all()), [])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [])
def test_update_empty_relationship(self):
"""
@@ -467,11 +594,11 @@ class ManyToManyTests(TestCase):
new_anchor.save()
data = {'rel': []}
serializer = self.serializer_class(self.instance, data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(list(instance.rel.all()), [])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(list(instance.rel.all()), [])
def test_create_empty_relationship_flat_data(self):
"""
@@ -479,19 +606,20 @@ class ManyToManyTests(TestCase):
containing no items, using a representation that does not support
lists (eg form data).
"""
- data = {'rel': ''}
+ data = MultiValueDict()
+ data.setlist('rel', [''])
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ManyToManyModel.objects.all()), 2)
- self.assertEquals(instance.pk, 2)
- self.assertEquals(list(instance.rel.all()), [])
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [])
class ReadOnlyManyToManyTests(TestCase):
def setUp(self):
class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
- rel = serializers.ManyRelatedField(read_only=True)
+ rel = serializers.RelatedField(many=True, read_only=True)
class Meta:
model = ReadOnlyManyToManyModel
@@ -519,12 +647,12 @@ class ReadOnlyManyToManyTests(TestCase):
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
serializer = self.serializer_class(self.instance, data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
+ self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
# rel is still as original (1 entry)
- self.assertEquals(list(instance.rel.all()), [self.anchor])
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
def test_update_without_relationship(self):
"""
@@ -535,12 +663,12 @@ class ReadOnlyManyToManyTests(TestCase):
new_anchor.save()
data = {}
serializer = self.serializer_class(self.instance, data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
+ self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
# rel is still as original (1 entry)
- self.assertEquals(list(instance.rel.all()), [self.anchor])
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
class DefaultValueTests(TestCase):
@@ -555,35 +683,35 @@ class DefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(self.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(instance.text, 'foobar')
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'foobar')
def test_create_overriding_default(self):
data = {'text': 'overridden'}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(self.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(instance.text, 'overridden')
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'overridden')
def test_partial_update_default(self):
""" Regression test for issue #532 """
data = {'text': 'overridden'}
serializer = self.serializer_class(data=data, partial=True)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
data = {'extra': 'extra_value'}
serializer = self.serializer_class(instance=instance, data=data, partial=True)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(instance.extra, 'extra_value')
- self.assertEquals(instance.text, 'overridden')
+ self.assertEqual(instance.extra, 'extra_value')
+ self.assertEqual(instance.text, 'overridden')
class CallableDefaultValueTests(TestCase):
@@ -598,20 +726,20 @@ class CallableDefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(self.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(instance.text, 'foobar')
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'foobar')
def test_create_overriding_default(self):
data = {'text': 'overridden'}
serializer = self.serializer_class(data=data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
instance = serializer.save()
- self.assertEquals(len(self.objects.all()), 1)
- self.assertEquals(instance.pk, 1)
- self.assertEquals(instance.text, 'overridden')
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'overridden')
class ManyRelatedTests(TestCase):
@@ -660,6 +788,9 @@ class ManyRelatedTests(TestCase):
class RelatedTraversalTest(TestCase):
def test_nested_traversal(self):
+ """
+ Source argument should support dotted.source notation.
+ """
user = Person.objects.create(name="django")
post = BlogPost.objects.create(title="Test blog post", writer=user)
post.blogpostcomment_set.create(text="I love this blog post")
@@ -686,11 +817,11 @@ class RelatedTraversalTest(TestCase):
serializer = BlogPostSerializer(instance=post)
expected = {
- 'title': u'Test blog post',
+ 'title': 'Test blog post',
'comments': [{
- 'text': u'I love this blog post',
+ 'text': 'I love this blog post',
'post_owner': {
- "name": u"django",
+ "name": "django",
"age": None
}
}]
@@ -698,6 +829,41 @@ class RelatedTraversalTest(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_nested_traversal_with_none(self):
+ """
+ If a component of the dotted.source is None, return None for the field.
+ """
+ from rest_framework.tests.models import NullableForeignKeySource
+ instance = NullableForeignKeySource.objects.create(name='Source with null FK')
+
+ class NullableSourceSerializer(serializers.Serializer):
+ target_name = serializers.Field(source='target.name')
+
+ serializer = NullableSourceSerializer(instance=instance)
+
+ expected = {
+ 'target_name': None,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_queryset_nested_traversal(self):
+ """
+ Relational fields should be able to use methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+
+ class QuerysetMethodSerializer(serializers.Serializer):
+ blogposts = serializers.RelatedField(many=True, source='get_all_blogposts')
+
+ class ClassWithQuerysetMethod(object):
+ def get_all_blogposts(self):
+ return BlogPost.objects
+
+ obj = ClassWithQuerysetMethod()
+ serializer = QuerysetMethodSerializer(obj)
+ self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']})
+
class SerializerMethodFieldTests(TestCase):
def setUp(self):
@@ -725,8 +891,8 @@ class SerializerMethodFieldTests(TestCase):
serializer = self.serializer_class(source_data)
expected = {
- 'beep': u'hello!',
- 'boop': [u'a', u'b', u'c'],
+ 'beep': 'hello!',
+ 'boop': ['a', 'b', 'c'],
'boop_count': 3,
}
@@ -742,7 +908,7 @@ class BlankFieldTests(TestCase):
model = BlankFieldModel
class BlankFieldSerializer(serializers.Serializer):
- title = serializers.CharField(blank=True)
+ title = serializers.CharField(required=False)
class NotBlankFieldModelSerializer(serializers.ModelSerializer):
class Meta:
@@ -759,15 +925,15 @@ class BlankFieldTests(TestCase):
def test_create_blank_field(self):
serializer = self.serializer_class(data=self.data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
def test_create_model_blank_field(self):
serializer = self.model_serializer_class(data=self.data)
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
def test_create_model_null_field(self):
serializer = self.model_serializer_class(data={'title': None})
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
def test_create_not_blank_field(self):
"""
@@ -775,7 +941,7 @@ class BlankFieldTests(TestCase):
is considered invalid in a non-model serializer
"""
serializer = self.not_blank_serializer_class(data=self.data)
- self.assertEquals(serializer.is_valid(), False)
+ self.assertEqual(serializer.is_valid(), False)
def test_create_model_not_blank_field(self):
"""
@@ -783,11 +949,11 @@ class BlankFieldTests(TestCase):
is considered invalid in a model serializer
"""
serializer = self.not_blank_model_serializer_class(data=self.data)
- self.assertEquals(serializer.is_valid(), False)
+ self.assertEqual(serializer.is_valid(), False)
- def test_create_model_null_field(self):
+ def test_create_model_empty_field(self):
serializer = self.model_serializer_class(data={})
- self.assertEquals(serializer.is_valid(), True)
+ self.assertEqual(serializer.is_valid(), True)
#test for issue #460
@@ -811,7 +977,21 @@ class SerializerPickleTests(TestCase):
class Meta:
model = Person
fields = ('name', 'age')
- pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data)
+ pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0)
+
+ def test_getstate_method_should_not_return_none(self):
+ """
+ Regression test for #645.
+ """
+ data = serializers.DictWithMetadata({1: 1})
+ self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1}))
+
+ def test_serializer_data_is_pickleable(self):
+ """
+ Another regression test for #645.
+ """
+ data = serializers.SortedDictWithMetadata({1: 1})
+ repr(pickle.loads(pickle.dumps(data, 0)))
class DepthTest(TestCase):
@@ -825,8 +1005,8 @@ class DepthTest(TestCase):
depth = 1
serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': u'Test blog post',
- 'writer': {'id': 1, 'name': u'django', 'age': 1}}
+ expected = {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}
self.assertEqual(serializer.data, expected)
@@ -845,8 +1025,8 @@ class DepthTest(TestCase):
model = BlogPost
serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': u'Test blog post',
- 'writer': {'id': 1, 'name': u'django', 'age': 1}}
+ expected = {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}
self.assertEqual(serializer.data, expected)
@@ -901,3 +1081,32 @@ class NestedSerializerContextTests(TestCase):
# This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
+
+
+class DeserializeListTestCase(TestCase):
+
+ def setUp(self):
+ self.data = {
+ 'email': 'nobody@nowhere.com',
+ 'content': 'This is some test content',
+ 'created': datetime.datetime(2013, 3, 7),
+ }
+
+ def test_no_errors(self):
+ data = [self.data.copy() for x in range(0, 3)]
+ serializer = CommentSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, list))
+ self.assertTrue(
+ all((isinstance(item, Comment) for item in serializer.object))
+ )
+
+ def test_errors_return_as_list(self):
+ invalid_item = self.data.copy()
+ invalid_item['email'] = ''
+ data = [self.data.copy(), invalid_item, self.data.copy()]
+
+ serializer = CommentSerializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ expected = [{}, {'email': ['This field is required.']}, {}]
+ self.assertEqual(serializer.errors, expected)
diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py
index 0293fdc3..857375c2 100644
--- a/rest_framework/tests/settings.py
+++ b/rest_framework/tests/settings.py
@@ -1,4 +1,5 @@
"""Tests for the settings module"""
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py
index 30df5cef..e1644a6b 100644
--- a/rest_framework/tests/status.py
+++ b/rest_framework/tests/status.py
@@ -1,4 +1,5 @@
"""Tests for the status module"""
+from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import status
@@ -8,5 +9,5 @@ class TestStatus(TestCase):
def test_status(self):
"""Ensure the status module is present and correct."""
- self.assertEquals(200, status.HTTP_200_OK)
- self.assertEquals(404, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(200, status.HTTP_200_OK)
+ self.assertEqual(404, status.HTTP_404_NOT_FOUND)
diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py
index 97f492ff..f8c2579e 100644
--- a/rest_framework/tests/testcases.py
+++ b/rest_framework/tests/testcases.py
@@ -1,4 +1,5 @@
# http://djangosnippets.org/snippets/1011/
+from __future__ import unicode_literals
from django.conf import settings
from django.core.management import call_command
from django.db.models import loading
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
index adeaf6da..08f88e11 100644
--- a/rest_framework/tests/tests.py
+++ b/rest_framework/tests/tests.py
@@ -2,6 +2,7 @@
Force import of all modules in this package in order to get the standard test
runner to pick up the tests. Yowzers.
"""
+from __future__ import unicode_literals
import os
modules = [filename.rsplit('.', 1)[0]
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py
index 4b98b941..11cbd8eb 100644
--- a/rest_framework/tests/throttling.py
+++ b/rest_framework/tests/throttling.py
@@ -1,11 +1,10 @@
"""
Tests for the throttling implementations in the permissions module.
"""
-
+from __future__ import unicode_literals
from django.test import TestCase
from django.contrib.auth.models import User
from django.core.cache import cache
-
from django.test.client import RequestFactory
from rest_framework.views import APIView
from rest_framework.throttling import UserRateThrottle
@@ -104,7 +103,7 @@ class ThrottlingTests(TestCase):
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
if expect is not None:
- self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
+ self.assertEqual(response['X-Throttle-Wait-Seconds'], expect)
else:
self.assertFalse('X-Throttle-Wait-Seconds' in response)
diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py
new file mode 100644
index 00000000..29ed4a96
--- /dev/null
+++ b/rest_framework/tests/urlpatterns.py
@@ -0,0 +1,76 @@
+from __future__ import unicode_literals
+from collections import namedtuple
+from django.core import urlresolvers
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework.compat import patterns, url, include
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+# A container class for test paths for the test case
+URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+
+
+def dummy_view(request, *args, **kwargs):
+ pass
+
+
+class FormatSuffixTests(TestCase):
+ """
+ Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
+ """
+ def _resolve_urlpatterns(self, urlpatterns, test_paths):
+ factory = RequestFactory()
+ try:
+ urlpatterns = format_suffix_patterns(urlpatterns)
+ except Exception:
+ self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
+ resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
+ for test_path in test_paths:
+ request = factory.get(test_path.path)
+ try:
+ callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ except Exception:
+ self.fail("Failed to resolve URL: %s" % request.path_info)
+ self.assertEqual(callback_args, test_path.args)
+ self.assertEqual(callback_kwargs, test_path.kwargs)
+
+ def test_format_suffix(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {}),
+ URLTestPath('/test.api', (), {'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_default_args(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view, {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {'foo': 'bar', }),
+ URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_included_urls(self):
+ nested_patterns = patterns(
+ '',
+ url(r'^path$', dummy_view)
+ )
+ urlpatterns = patterns(
+ '',
+ url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test/path', (), {'foo': 'bar', }),
+ URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py
index 3906adb9..8c87917d 100644
--- a/rest_framework/tests/utils.py
+++ b/rest_framework/tests/utils.py
@@ -1,9 +1,10 @@
-from django.test.client import RequestFactory, FakePayload
+from __future__ import unicode_literals
+from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory
from django.test.client import MULTIPART_CONTENT
-from urlparse import urlparse
+from rest_framework.compat import urlparse
-class RequestFactory(RequestFactory):
+class RequestFactory(_RequestFactory):
def __init__(self, **defaults):
super(RequestFactory, self).__init__(**defaults)
@@ -14,7 +15,7 @@ class RequestFactory(RequestFactory):
patch_data = self._encode_data(data, content_type)
- parsed = urlparse(path)
+ parsed = urlparse.urlparse(path)
r = {
'CONTENT_LENGTH': len(patch_data),
'CONTENT_TYPE': content_type,
@@ -25,3 +26,15 @@ class RequestFactory(RequestFactory):
}
r.update(extra)
return self.request(**r)
+
+
+class Client(_Client, RequestFactory):
+ def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
+ follow=False, **extra):
+ """
+ Send a resource to the server using PATCH.
+ """
+ response = super(Client, self).patch(path, data=data, content_type=content_type, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/validation.py
new file mode 100644
index 00000000..cbdd6515
--- /dev/null
+++ b/rest_framework/tests/validation.py
@@ -0,0 +1,65 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import generics, serializers, status
+from rest_framework.tests.utils import RequestFactory
+import json
+
+factory = RequestFactory()
+
+
+# Regression for #666
+
+class ValidationModel(models.Model):
+ blank_validated_field = models.CharField(max_length=255)
+
+
+class ValidationModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ValidationModel
+ fields = ('blank_validated_field',)
+ read_only_fields = ('blank_validated_field',)
+
+
+class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
+ model = ValidationModel
+ serializer_class = ValidationModelSerializer
+
+
+class TestPreSaveValidationExclusions(TestCase):
+ def test_pre_save_validation_exclusions(self):
+ """
+ Somewhat weird test case to ensure that we don't perform model
+ validation on read only fields.
+ """
+ obj = ValidationModel.objects.create(blank_validated_field='')
+ request = factory.put('/', json.dumps({}),
+ content_type='application/json')
+ view = UpdateValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+# Regression for #653
+
+class ShouldValidateModel(models.Model):
+ should_validate_field = models.CharField(max_length=255)
+
+
+class ShouldValidateModelSerializer(serializers.ModelSerializer):
+ renamed = serializers.CharField(source='should_validate_field', required=False)
+
+ class Meta:
+ model = ShouldValidateModel
+ fields = ('renamed',)
+
+
+class TestPreSaveValidationExclusions(TestCase):
+ def test_renamed_fields_are_model_validated(self):
+ """
+ Ensure fields with 'source' applied do get still get model validation.
+ """
+ # We've set `required=False` on the serializer, but the model
+ # does not have `blank=True`, so this serializer should not validate.
+ serializer = ShouldValidateModelSerializer(data={'renamed': ''})
+ self.assertEqual(serializer.is_valid(), False)
diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py
deleted file mode 100644
index c032985e..00000000
--- a/rest_framework/tests/validators.py
+++ /dev/null
@@ -1,329 +0,0 @@
-# from django import forms
-# from django.db import models
-# from django.test import TestCase
-# from rest_framework.response import ImmediateResponse
-# from rest_framework.views import View
-
-
-# class TestDisabledValidations(TestCase):
-# """Tests on FormValidator with validation disabled by setting form to None"""
-
-# def test_disabled_form_validator_returns_content_unchanged(self):
-# """If the view's form attribute is None then FormValidator(view).validate_request(content, None)
-# should just return the content unmodified."""
-# class DisabledFormResource(FormResource):
-# form = None
-
-# class MockView(View):
-# resource = DisabledFormResource
-
-# view = MockView()
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(FormResource(view).validate_request(content, None), content)
-
-# def test_disabled_form_validator_get_bound_form_returns_none(self):
-# """If the view's form attribute is None on then
-# FormValidator(view).get_bound_form(content) should just return None."""
-# class DisabledFormResource(FormResource):
-# form = None
-
-# class MockView(View):
-# resource = DisabledFormResource
-
-# view = MockView()
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(FormResource(view).get_bound_form(content), None)
-
-# def test_disabled_model_form_validator_returns_content_unchanged(self):
-# """If the view's form is None and does not have a Resource with a model set then
-# ModelFormValidator(view).validate_request(content, None) should just return the content unmodified."""
-
-# class DisabledModelFormView(View):
-# resource = ModelResource
-
-# view = DisabledModelFormView()
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(ModelResource(view).get_bound_form(content), None)
-
-# def test_disabled_model_form_validator_get_bound_form_returns_none(self):
-# """If the form attribute is None on FormValidatorMixin then get_bound_form(content) should just return None."""
-# class DisabledModelFormView(View):
-# resource = ModelResource
-
-# view = DisabledModelFormView()
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(ModelResource(view).get_bound_form(content), None)
-
-
-# class TestNonFieldErrors(TestCase):
-# """Tests against form validation errors caused by non-field errors. (eg as might be caused by some custom form validation)"""
-
-# def test_validate_failed_due_to_non_field_error_returns_appropriate_message(self):
-# """If validation fails with a non-field error, ensure the response a non-field error"""
-# class MockForm(forms.Form):
-# field1 = forms.CharField(required=False)
-# field2 = forms.CharField(required=False)
-# ERROR_TEXT = 'You may not supply both field1 and field2'
-
-# def clean(self):
-# if 'field1' in self.cleaned_data and 'field2' in self.cleaned_data:
-# raise forms.ValidationError(self.ERROR_TEXT)
-# return self.cleaned_data
-
-# class MockResource(FormResource):
-# form = MockForm
-
-# class MockView(View):
-# pass
-
-# view = MockView()
-# content = {'field1': 'example1', 'field2': 'example2'}
-# try:
-# MockResource(view).validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'errors': [MockForm.ERROR_TEXT]})
-# else:
-# self.fail('ImmediateResponse was not raised')
-
-
-# class TestFormValidation(TestCase):
-# """Tests which check basic form validation.
-# Also includes the same set of tests with a ModelFormValidator for which the form has been explicitly set.
-# (ModelFormValidator should behave as FormValidator if a form is set rather than relying on the default ModelForm)"""
-# def setUp(self):
-# class MockForm(forms.Form):
-# qwerty = forms.CharField(required=True)
-
-# class MockFormResource(FormResource):
-# form = MockForm
-
-# class MockModelResource(ModelResource):
-# form = MockForm
-
-# class MockFormView(View):
-# resource = MockFormResource
-
-# class MockModelFormView(View):
-# resource = MockModelResource
-
-# self.MockFormResource = MockFormResource
-# self.MockModelResource = MockModelResource
-# self.MockFormView = MockFormView
-# self.MockModelFormView = MockModelFormView
-
-# def validation_returns_content_unchanged_if_already_valid_and_clean(self, validator):
-# """If the content is already valid and clean then validate(content) should just return the content unmodified."""
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(validator.validate_request(content, None), content)
-
-# def validation_failure_raises_response_exception(self, validator):
-# """If form validation fails a ResourceException 400 (Bad Request) should be raised."""
-# content = {}
-# self.assertRaises(ImmediateResponse, validator.validate_request, content, None)
-
-# def validation_does_not_allow_extra_fields_by_default(self, validator):
-# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
-# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
-# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'qwerty': 'uiop', 'extra': 'extra'}
-# self.assertRaises(ImmediateResponse, validator.validate_request, content, None)
-
-# def validation_allows_extra_fields_if_explicitly_set(self, validator):
-# """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names."""
-# content = {'qwerty': 'uiop', 'extra': 'extra'}
-# validator._validate(content, None, allowed_extra_fields=('extra',))
-
-# def validation_allows_unknown_fields_if_explicitly_allowed(self, validator):
-# """If we set ``unknown_form_fields`` on the form resource, then don't
-# raise errors on unexpected request data"""
-# content = {'qwerty': 'uiop', 'extra': 'extra'}
-# validator.allow_unknown_form_fields = True
-# self.assertEqual({'qwerty': u'uiop'},
-# validator.validate_request(content, None),
-# "Resource didn't accept unknown fields.")
-# validator.allow_unknown_form_fields = False
-
-# def validation_does_not_require_extra_fields_if_explicitly_set(self, validator):
-# """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names."""
-# content = {'qwerty': 'uiop'}
-# self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content)
-
-# def validation_failed_due_to_no_content_returns_appropriate_message(self, validator):
-# """If validation fails due to no content, ensure the response contains a single non-field error"""
-# content = {}
-# try:
-# validator.validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}})
-# else:
-# self.fail('ResourceException was not raised')
-
-# def validation_failed_due_to_field_error_returns_appropriate_message(self, validator):
-# """If validation fails due to a field error, ensure the response contains a single field error"""
-# content = {'qwerty': ''}
-# try:
-# validator.validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}})
-# else:
-# self.fail('ResourceException was not raised')
-
-# def validation_failed_due_to_invalid_field_returns_appropriate_message(self, validator):
-# """If validation fails due to an invalid field, ensure the response contains a single field error"""
-# content = {'qwerty': 'uiop', 'extra': 'extra'}
-# try:
-# validator.validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'field_errors': {'extra': ['This field does not exist.']}})
-# else:
-# self.fail('ResourceException was not raised')
-
-# def validation_failed_due_to_multiple_errors_returns_appropriate_message(self, validator):
-# """If validation for multiple reasons, ensure the response contains each error"""
-# content = {'qwerty': '', 'extra': 'extra'}
-# try:
-# validator.validate_request(content, None)
-# except ImmediateResponse, exc:
-# response = exc.response
-# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.'],
-# 'extra': ['This field does not exist.']}})
-# else:
-# self.fail('ResourceException was not raised')
-
-# # Tests on FormResource
-
-# def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
-
-# def test_form_validation_failure_raises_response_exception(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failure_raises_response_exception(validator)
-
-# def test_validation_does_not_allow_extra_fields_by_default(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_does_not_allow_extra_fields_by_default(validator)
-
-# def test_validation_allows_extra_fields_if_explicitly_set(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_allows_extra_fields_if_explicitly_set(validator)
-
-# def test_validation_allows_unknown_fields_if_explicitly_allowed(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_allows_unknown_fields_if_explicitly_allowed(validator)
-
-# def test_validation_does_not_require_extra_fields_if_explicitly_set(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
-
-# def test_validation_failed_due_to_no_content_returns_appropriate_message(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
-
-# def test_validation_failed_due_to_field_error_returns_appropriate_message(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
-
-# def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
-
-# def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
-# validator = self.MockFormResource(self.MockFormView())
-# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
-
-# # Same tests on ModelResource
-
-# def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
-
-# def test_modelform_validation_failure_raises_response_exception(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failure_raises_response_exception(validator)
-
-# def test_modelform_validation_does_not_allow_extra_fields_by_default(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_does_not_allow_extra_fields_by_default(validator)
-
-# def test_modelform_validation_allows_extra_fields_if_explicitly_set(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_allows_extra_fields_if_explicitly_set(validator)
-
-# def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
-
-# def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
-
-# def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
-
-# def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
-
-# def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
-# validator = self.MockModelResource(self.MockModelFormView())
-# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
-
-
-# class TestModelFormValidator(TestCase):
-# """Tests specific to ModelFormValidatorMixin"""
-
-# def setUp(self):
-# """Create a validator for a model with two fields and a property."""
-# class MockModel(models.Model):
-# qwerty = models.CharField(max_length=256)
-# uiop = models.CharField(max_length=256, blank=True)
-
-# @property
-# def read_only(self):
-# return 'read only'
-
-# class MockResource(ModelResource):
-# model = MockModel
-
-# class MockView(View):
-# resource = MockResource
-
-# self.validator = MockResource(MockView)
-
-# def test_property_fields_are_allowed_on_model_forms(self):
-# """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
-# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
-# self.assertEqual(self.validator.validate_request(content, None), content)
-
-# def test_property_fields_are_not_required_on_model_forms(self):
-# """Validation on ModelForms does not require property fields that exist on the Model to be included in the input."""
-# content = {'qwerty': 'example', 'uiop': 'example'}
-# self.assertEqual(self.validator.validate_request(content, None), content)
-
-# def test_extra_fields_not_allowed_on_model_forms(self):
-# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
-# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
-# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
-# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
-
-# def test_validate_requires_fields_on_model_forms(self):
-# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
-# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
-# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'read_only': 'read only'}
-# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
-
-# def test_validate_does_not_require_blankable_fields_on_model_forms(self):
-# """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
-# content = {'qwerty': 'example', 'read_only': 'read only'}
-# self.validator.validate_request(content, None)
-
-# def test_model_form_validator_uses_model_forms(self):
-# self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm))
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
index 7cd82656..994cf6dc 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/views.py
@@ -1,4 +1,4 @@
-import copy
+from __future__ import unicode_literals
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import status
@@ -6,6 +6,7 @@ from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
+import copy
factory = RequestFactory()
@@ -49,10 +50,10 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
expected = {
- 'detail': u'JSON parse error - No JSON object could be decoded'
+ 'detail': 'JSON parse error - No JSON object could be decoded'
}
- self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEquals(sanitise_json_error(response.data), expected)
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
def test_400_parse_error_tunneled_content(self):
content = 'f00bar'
@@ -64,10 +65,10 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data)
response = self.view(request)
expected = {
- 'detail': u'JSON parse error - No JSON object could be decoded'
+ 'detail': 'JSON parse error - No JSON object could be decoded'
}
- self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEquals(sanitise_json_error(response.data), expected)
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
class FunctionBasedViewIntegrationTests(TestCase):
@@ -78,10 +79,10 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
expected = {
- 'detail': u'JSON parse error - No JSON object could be decoded'
+ 'detail': 'JSON parse error - No JSON object could be decoded'
}
- self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEquals(sanitise_json_error(response.data), expected)
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
def test_400_parse_error_tunneled_content(self):
content = 'f00bar'
@@ -93,7 +94,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data)
response = self.view(request)
expected = {
- 'detail': u'JSON parse error - No JSON object could be decoded'
+ 'detail': 'JSON parse error - No JSON object could be decoded'
}
- self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEquals(sanitise_json_error(response.data), expected)
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)