diff options
Diffstat (limited to 'rest_framework/tests')
38 files changed, 2855 insertions, 1454 deletions
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index e86041bc..b663ca48 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,33 +1,65 @@ +from __future__ import unicode_literals from django.contrib.auth.models import User from django.http import HttpResponse from django.test import Client, TestCase - +from django.utils import unittest +from rest_framework import HTTP_HEADER_ENCODING +from rest_framework import exceptions from rest_framework import permissions +from rest_framework import status +from rest_framework.authentication import ( + BaseAuthentication, + TokenAuthentication, + BasicAuthentication, + SessionAuthentication, + OAuthAuthentication, + OAuth2Authentication +) from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication -from rest_framework.compat import patterns +from rest_framework.compat import patterns, url, include +from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope +from rest_framework.compat import oauth, oauth_provider +from rest_framework.tests.utils import RequestFactory from rest_framework.views import APIView - import json import base64 +import time +import datetime + +factory = RequestFactory() class MockView(APIView): permission_classes = (permissions.IsAuthenticated,) + def get(self, request): + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + def post(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) def put(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) -MockView.authentication_classes += (TokenAuthentication,) urlpatterns = patterns('', - (r'^$', MockView.as_view()), + (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), + (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), + (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), + (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), + (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], + permission_classes=[permissions.TokenHasReadWriteScope])) ) +if oauth2_provider is not None: + urlpatterns += patterns('', + url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), + url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], + permission_classes=[permissions.TokenHasReadWriteScope])), + ) + class BasicAuthTests(TestCase): """Basic authentication""" @@ -42,25 +74,30 @@ class BasicAuthTests(TestCase): def test_post_form_passing_basic_auth(self): """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" - auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + credentials = ('%s:%s' % (self.username, self.password)) + base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) + auth = 'Basic %s' % base64_credentials + response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_json_passing_basic_auth(self): """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" - auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + credentials = ('%s:%s' % (self.username, self.password)) + base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) + auth = 'Basic %s' % base64_credentials + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_basic_auth(self): """Ensure POSTing form over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') class SessionAuthTests(TestCase): @@ -83,31 +120,31 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication without CSRF token fails. """ self.csrf_client.login(username=self.username, password=self.password) - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/session/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_post_form_session_auth_passing(self): """ Ensure POSTing form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 200) + response = self.non_csrf_client.post('/session/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_put_form_session_auth_passing(self): """ Ensure PUTting form over session authentication with logged in user and CSRF token passes. """ self.non_csrf_client.login(username=self.username, password=self.password) - response = self.non_csrf_client.put('/', {'example': 'example'}) - self.assertEqual(response.status_code, 200) + response = self.non_csrf_client.put('/session/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_session_auth_failing(self): """ Ensure POSTing form over session authentication without logged in user fails. """ - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/session/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) class TokenAuthTests(TestCase): @@ -126,25 +163,25 @@ class TokenAuthTests(TestCase): def test_post_form_passing_token_auth(self): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" - auth = "Token " + self.key - response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + auth = 'Token ' + self.key + response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_token_auth(self): """Ensure POSTing form over token auth without correct credentials fails""" - response = self.csrf_client.post('/', {'example': 'example'}) - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', {'example': 'example'}) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') - self.assertEqual(response.status_code, 403) + response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_token_has_auto_assigned_key_if_none_provided(self): """Ensure creating a token with no key will auto-assign a key""" @@ -157,8 +194,8 @@ class TokenAuthTests(TestCase): client = Client(enforce_csrf_checks=True) response = client.post('/auth-token/', json.dumps({'username': self.username, 'password': self.password}), 'application/json') - self.assertEqual(response.status_code, 200) - self.assertEqual(json.loads(response.content)['token'], self.key) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) def test_token_login_json_bad_creds(self): """Ensure token login view using JSON POST fails if bad credentials are used.""" @@ -179,5 +216,362 @@ class TokenAuthTests(TestCase): client = Client(enforce_csrf_checks=True) response = client.post('/auth-token/', {'username': self.username, 'password': self.password}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + + +class IncorrectCredentialsTests(TestCase): + def test_incorrect_credentials(self): + """ + If a request contains bad authentication credentials, then + authentication should run and error, even if no permissions + are set on the view. + """ + class IncorrectCredentialsAuth(BaseAuthentication): + def authenticate(self, request): + raise exceptions.AuthenticationFailed('Bad credentials') + + request = factory.get('/') + view = MockView.as_view( + authentication_classes=(IncorrectCredentialsAuth,), + permission_classes=() + ) + response = view(request) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.data, {'detail': 'Bad credentials'}) + + +class OAuthTests(TestCase): + """OAuth 1.0a authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + # these imports are here because oauth is optional and hiding them in try..except block or compat + # could obscure problems if something breaks + from oauth_provider.models import Consumer, Resource + from oauth_provider.models import Token as OAuthToken + from oauth_provider import consts + + self.consts = consts + + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.CONSUMER_KEY = 'consumer_key' + self.CONSUMER_SECRET = 'consumer_secret' + self.TOKEN_KEY = "token_key" + self.TOKEN_SECRET = "token_secret" + + self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, + name='example', user=self.user, status=self.consts.ACCEPTED) + + self.resource = Resource.objects.create(name="resource name", url="api/") + self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource, + token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True + ) + + def _create_authorization_header(self): + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="GET", url="http://example.com", parameters=params) + + signature_method = oauth.SignatureMethod_PLAINTEXT() + req.sign_request(signature_method, self.consumer, self.token) + + return req.to_header()["Authorization"] + + def _create_authorization_url_parameters(self): + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="GET", url="http://example.com", parameters=params) + + signature_method = oauth.SignatureMethod_PLAINTEXT() + req.sign_request(signature_method, self.consumer, self.token) + return dict(req) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_passing_oauth(self): + """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_repeated_nonce_failing_oauth(self): + """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails""" + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + # simulate reply attack auth header containes already used (nonce, timestamp) pair + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_token_removed_failing_oauth(self): + """Ensure POSTing when there is no OAuth access token in db fails""" + self.token.delete() + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_consumer_status_not_accepted_failing_oauth(self): + """Ensure POSTing when consumer status is anything other than ACCEPTED fails""" + for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED): + self.consumer.status = consumer_status + self.consumer.save() + + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_request_token_failing_oauth(self): + """Ensure POSTing with unauthorized request token instead of access token fails""" + self.token.token_type = self.token.REQUEST + self.token.save() + + auth = self._create_authorization_header() + response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_urlencoded_parameters(self): + """Ensure POSTing with x-www-form-urlencoded auth parameters passes""" + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_get_form_with_url_parameters(self): + """Ensure GETing with auth in url parameters passes""" + params = self._create_authorization_url_parameters() + response = self.csrf_client.get('/oauth/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_hmac_sha1_signature_passes(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_get_form_with_readonly_resource_passing_auth(self): + """Ensure POSTing with a readonly resource instead of a write scope fails""" + read_only_access_token = self.token + read_only_access_token.resource.is_readonly = True + read_only_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.get('/oauth-with-scope/', params) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_readonly_resource_failing_auth(self): + """Ensure POSTing with a readonly resource instead of a write scope fails""" + read_only_access_token = self.token + read_only_access_token.resource.is_readonly = True + read_only_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth-with-scope/', params) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_post_form_with_write_resource_passing_auth(self): + """Ensure POSTing with a write resource succeed""" + read_write_access_token = self.token + read_write_access_token.resource.is_readonly = False + read_write_access_token.resource.save() + params = self._create_authorization_url_parameters() + response = self.csrf_client.post('/oauth-with-scope/', params) + self.assertEqual(response.status_code, 200) + + +class OAuth2Tests(TestCase): + """OAuth 2.0 authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.CLIENT_ID = 'client_key' + self.CLIENT_SECRET = 'client_secret' + self.ACCESS_TOKEN = "access_token" + self.REFRESH_TOKEN = "refresh_token" + + self.oauth2_client = oauth2_provider_models.Client.objects.create( + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + redirect_uri='', + client_type=0, + name='example', + user=None, + ) + + self.access_token = oauth2_provider_models.AccessToken.objects.create( + token=self.ACCESS_TOKEN, + client=self.oauth2_client, + user=self.user, + ) + self.refresh_token = oauth2_provider_models.RefreshToken.objects.create( + user=self.user, + access_token=self.access_token, + client=self.oauth2_client + ) + + def _create_authorization_header(self, token=None): + return "Bearer {0}".format(token or self.access_token.token) + + def _client_credentials_params(self): + return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_type_failing(self): + """Ensure that a wrong token type lead to the correct HTTP error status code""" + auth = "Wrong token-type-obsviously" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_format_failing(self): + """Ensure that a wrong token format lead to the correct HTTP error status code""" + auth = "Bearer wrong token format" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_authorization_header_token_failing(self): + """Ensure that a wrong token lead to the correct HTTP error status code""" + auth = "Bearer wrong-token" + response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_with_wrong_client_data_failing_auth(self): + """Ensure GETing form over OAuth with incorrect client credentials fails""" + auth = self._create_authorization_header() + params = self._client_credentials_params() + params['client_id'] += 'a' + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_auth(self): + """Ensure GETing form over OAuth with correct client credentials succeed""" + auth = self._create_authorization_header() + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_passing_auth(self): + """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" + auth = self._create_authorization_header() + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_token_removed_failing_auth(self): + """Ensure POSTing when there is no OAuth access token in db fails""" + self.access_token.delete() + auth = self._create_authorization_header() + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_refresh_token_failing_auth(self): + """Ensure POSTing with refresh token instead of access token fails""" + auth = self._create_authorization_header(token=self.refresh_token.token) + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_expired_access_token_failing_auth(self): + """Ensure POSTing with expired access token fails with an 'Invalid token' error""" + self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late + self.access_token.save() + auth = self._create_authorization_header() + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + self.assertIn('Invalid token', response.content) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_invalid_scope_failing_auth(self): + """Ensure POSTing with a readonly scope instead of a write scope fails""" + read_only_access_token = self.access_token + read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] + read_only_access_token.save() + auth = self._create_authorization_header(token=read_only_access_token.token) + params = self._client_credentials_params() + response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_post_form_with_valid_scope_passing_auth(self): + """Ensure POSTing with a write scope succeed""" + read_write_access_token = self.access_token + read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] + read_write_access_token.save() + auth = self._create_authorization_header(token=read_write_access_token.token) + params = self._client_credentials_params() + response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) - self.assertEqual(json.loads(response.content)['token'], self.key) diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py index df891683..d9ed647e 100644 --- a/rest_framework/tests/breadcrumbs.py +++ b/rest_framework/tests/breadcrumbs.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from django.test import TestCase from rest_framework.compat import patterns, url from rest_framework.utils.breadcrumbs import get_breadcrumbs diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 5e6bce4e..1016fed3 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from django.test import TestCase from rest_framework import status from rest_framework.response import Response @@ -28,13 +29,27 @@ class DecoratorTestCase(TestCase): response.request = request return APIView.finalize_response(self, request, response, *args, **kwargs) - def test_wrap_view(self): + def test_api_view_incorrect(self): + """ + If @api_view is not applied correct, we should raise an assertion. + """ - @api_view(['GET']) + @api_view def view(request): - return Response({}) + return Response() + + request = self.factory.get('/') + self.assertRaises(AssertionError, view, request) + + def test_api_view_incorrect_arguments(self): + """ + If @api_view is missing arguments, we should raise an assertion. + """ - self.assertTrue(isinstance(view.cls_instance, APIView)) + with self.assertRaises(AssertionError): + @api_view('GET') + def view(request): + return Response() def test_calling_method(self): @@ -44,11 +59,11 @@ class DecoratorTestCase(TestCase): request = self.factory.get('/') response = view(request) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) request = self.factory.post('/') response = view(request) - self.assertEqual(response.status_code, 405) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) def test_calling_put_method(self): @@ -58,11 +73,11 @@ class DecoratorTestCase(TestCase): request = self.factory.put('/') response = view(request) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) request = self.factory.post('/') response = view(request) - self.assertEqual(response.status_code, 405) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) def test_calling_patch_method(self): @@ -72,11 +87,11 @@ class DecoratorTestCase(TestCase): request = self.factory.patch('/') response = view(request) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, status.HTTP_200_OK) request = self.factory.post('/') response = view(request) - self.assertEqual(response.status_code, 405) + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) def test_renderer_classes(self): @@ -124,7 +139,7 @@ class DecoratorTestCase(TestCase): request = self.factory.get('/') response = view(request) - self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_throttle_classes(self): class OncePerDayUserThrottle(UserRateThrottle): @@ -137,7 +152,7 @@ class DecoratorTestCase(TestCase): request = self.factory.get('/') response = view(request) - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = view(request) - self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py index d958b840..5b3315bc 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/description.py @@ -1,3 +1,6 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals from django.test import TestCase from rest_framework.views import APIView from rest_framework.compat import apply_markdown @@ -50,7 +53,7 @@ class TestViewNamesAndDescriptions(TestCase): """Ensure Resource names are based on the classname by default.""" class MockView(APIView): pass - self.assertEquals(MockView().get_name(), 'Mock') + self.assertEqual(MockView().get_name(), 'Mock') def test_resource_name_can_be_set_explicitly(self): """Ensure Resource names can be set using the 'get_name' method.""" @@ -58,7 +61,7 @@ class TestViewNamesAndDescriptions(TestCase): class MockView(APIView): def get_name(self): return example - self.assertEquals(MockView().get_name(), example) + self.assertEqual(MockView().get_name(), example) def test_resource_description_uses_docstring_by_default(self): """Ensure Resource names are based on the docstring by default.""" @@ -78,7 +81,7 @@ class TestViewNamesAndDescriptions(TestCase): # hash style header #""" - self.assertEquals(MockView().get_description(), DESCRIPTION) + self.assertEqual(MockView().get_description(), DESCRIPTION) def test_resource_description_can_be_set_explicitly(self): """Ensure Resource descriptions can be set using the 'get_description' method.""" @@ -88,7 +91,16 @@ class TestViewNamesAndDescriptions(TestCase): """docstring""" def get_description(self): return example - self.assertEquals(MockView().get_description(), example) + self.assertEqual(MockView().get_description(), example) + + def test_resource_description_supports_unicode(self): + + class MockView(APIView): + """Проверка""" + pass + + self.assertEqual(MockView().get_description(), "Проверка") + def test_resource_description_does_not_require_docstring(self): """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method.""" @@ -97,13 +109,13 @@ class TestViewNamesAndDescriptions(TestCase): class MockView(APIView): def get_description(self): return example - self.assertEquals(MockView().get_description(), example) + self.assertEqual(MockView().get_description(), example) def test_resource_description_can_be_empty(self): """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string.""" class MockView(APIView): pass - self.assertEquals(MockView().get_description(), '') + self.assertEqual(MockView().get_description(), '') def test_markdown(self): """Ensure markdown to HTML works as expected""" diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 8068272d..fd6de779 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -1,9 +1,13 @@ """ General serializer field tests. """ +from __future__ import unicode_literals +import datetime from django.db import models from django.test import TestCase +from django.core import validators + from rest_framework import serializers @@ -26,24 +30,415 @@ class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): model = CharPrimaryKeyModel -class ReadOnlyFieldTests(TestCase): +class TimeFieldModel(models.Model): + clock = models.TimeField() + + +class TimeFieldModelSerializer(serializers.ModelSerializer): + class Meta: + model = TimeFieldModel + + +class BasicFieldTests(TestCase): def test_auto_now_fields_read_only(self): """ auto_now and auto_now_add fields should be read_only by default. """ serializer = TimestampedModelSerializer() - self.assertEquals(serializer.fields['added'].read_only, True) + self.assertEqual(serializer.fields['added'].read_only, True) def test_auto_pk_fields_read_only(self): """ AutoField fields should be read_only by default. """ serializer = TimestampedModelSerializer() - self.assertEquals(serializer.fields['id'].read_only, True) + self.assertEqual(serializer.fields['id'].read_only, True) def test_non_auto_pk_fields_not_read_only(self): """ PK fields other than AutoField fields should not be read_only by default. """ serializer = CharPrimaryKeyModelSerializer() - self.assertEquals(serializer.fields['id'].read_only, False) + self.assertEqual(serializer.fields['id'].read_only, False) + + +class DateFieldTest(TestCase): + """ + Tests for the DateFieldTest from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ + f = serializers.DateField() + result_1 = f.from_native('1984-07-31') + + self.assertEqual(datetime.date(1984, 7, 31), result_1) + + def test_from_native_datetime_date(self): + """ + Make sure from_native() accepts a datetime.date instance. + """ + f = serializers.DateField() + result_1 = f.from_native(datetime.date(1984, 7, 31)) + + self.assertEqual(result_1, datetime.date(1984, 7, 31)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.DateField(input_formats=['%Y -- %d']) + result = f.from_native('1984 -- 31') + + self.assertEqual(datetime.date(1984, 1, 31), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.DateField(input_formats=['%Y -- %d']) + + try: + f.from_native('1984-07-31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DateField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_from_native_invalid_date(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid date. + """ + f = serializers.DateField() + + try: + f.from_native('1984-13-31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.DateField() + + try: + f.from_native('1984 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns isoformat as default. + """ + f = serializers.DateField() + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + + self.assertEqual('1984-07-31', result_1) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.DateField(format="%Y - %m.%d") + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + + self.assertEqual('1984 - 07.31', result_1) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateField(required=False) + self.assertEqual(None, f.to_native(None)) + + +class DateTimeFieldTest(TestCase): + """ + Tests for the DateTimeField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ + f = serializers.DateTimeField() + result_1 = f.from_native('1984-07-31 04:31') + result_2 = f.from_native('1984-07-31 04:31:59') + result_3 = f.from_native('1984-07-31 04:31:59.000200') + + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3) + + def test_from_native_datetime_datetime(self): + """ + Make sure from_native() accepts a datetime.datetime instance. + """ + f = serializers.DateTimeField() + result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31)) + self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59)) + self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) + result = f.from_native('1984 -- 04:59') + + self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) + + try: + f.from_native('1984-07-31 04:31:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.DateTimeField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateTimeField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_from_native_invalid_datetime(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid datetime. + """ + f = serializers.DateTimeField() + + try: + f.from_native('04:61:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " + "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.DateTimeField() + + try: + f.from_native('04 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " + "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns isoformat as default. + """ + f = serializers.DateTimeField() + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual('1984-07-31T00:00:00', result_1) + self.assertEqual('1984-07-31T04:31:00', result_2) + self.assertEqual('1984-07-31T04:31:59', result_3) + self.assertEqual('1984-07-31T04:31:59.000200', result_4) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.DateTimeField(format="%Y - %H:%M") + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + + self.assertEqual('1984 - 00:00', result_1) + self.assertEqual('1984 - 04:31', result_2) + self.assertEqual('1984 - 04:31', result_3) + self.assertEqual('1984 - 04:31', result_4) + + def test_to_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.DateTimeField(required=False) + self.assertEqual(None, f.to_native(None)) + + +class TimeFieldTest(TestCase): + """ + Tests for the TimeField from_native() and to_native() behavior + """ + + def test_from_native_string(self): + """ + Make sure from_native() accepts default iso input formats. + """ + f = serializers.TimeField() + result_1 = f.from_native('04:31') + result_2 = f.from_native('04:31:59') + result_3 = f.from_native('04:31:59.000200') + + self.assertEqual(datetime.time(4, 31), result_1) + self.assertEqual(datetime.time(4, 31, 59), result_2) + self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + + def test_from_native_datetime_time(self): + """ + Make sure from_native() accepts a datetime.time instance. + """ + f = serializers.TimeField() + result_1 = f.from_native(datetime.time(4, 31)) + result_2 = f.from_native(datetime.time(4, 31, 59)) + result_3 = f.from_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual(result_1, datetime.time(4, 31)) + self.assertEqual(result_2, datetime.time(4, 31, 59)) + self.assertEqual(result_3, datetime.time(4, 31, 59, 200)) + + def test_from_native_custom_format(self): + """ + Make sure from_native() accepts custom input formats. + """ + f = serializers.TimeField(input_formats=['%H -- %M']) + result = f.from_native('04 -- 31') + + self.assertEqual(datetime.time(4, 31), result) + + def test_from_native_invalid_default_on_custom_format(self): + """ + Make sure from_native() don't accept default formats if custom format is preset + """ + f = serializers.TimeField(input_formats=['%H -- %M']) + + try: + f.from_native('04:31:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.TimeField() + result = f.from_native('') + + self.assertEqual(result, None) + + def test_from_native_none(self): + """ + Make sure from_native() returns None on None param. + """ + f = serializers.TimeField() + result = f.from_native(None) + + self.assertEqual(result, None) + + def test_from_native_invalid_time(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid time. + """ + f = serializers.TimeField() + + try: + f.from_native('04:61:59') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " + "hh:mm[:ss[.uuuuuu]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_from_native_invalid_format(self): + """ + Make sure from_native() raises a ValidationError on passing an invalid format. + """ + f = serializers.TimeField() + + try: + f.from_native('04 -- 31') + except validators.ValidationError as e: + self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " + "hh:mm[:ss[.uuuuuu]]"]) + else: + self.fail("ValidationError was not properly raised") + + def test_to_native(self): + """ + Make sure to_native() returns isoformat as default. + """ + f = serializers.TimeField() + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual('04:31:00', result_1) + self.assertEqual('04:31:59', result_2) + self.assertEqual('04:31:59.000200', result_3) + + def test_to_native_custom_format(self): + """ + Make sure to_native() returns correct custom format. + """ + f = serializers.TimeField(format="%H - %S [%f]") + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + + self.assertEqual('04 - 00 [000000]', result_1) + self.assertEqual('04 - 59 [000000]', result_2) + self.assertEqual('04 - 59 [000200]', result_3) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py index 446e23c0..487046ac 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/files.py @@ -1,9 +1,9 @@ -import StringIO -import datetime - +from __future__ import unicode_literals from django.test import TestCase - from rest_framework import serializers +from rest_framework.compat import BytesIO +from rest_framework.compat import six +import datetime class UploadedFile(object): @@ -27,14 +27,14 @@ class UploadedFileSerializer(serializers.Serializer): class FileSerializerTests(TestCase): def test_create(self): now = datetime.datetime.now() - file = StringIO.StringIO('stuff') + file = BytesIO(six.b('stuff')) file.name = 'stuff.txt' - file.size = file.len + file.size = len(file.getvalue()) serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) uploaded_file = UploadedFile(file=file, created=now) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.object.created, uploaded_file.created) - self.assertEquals(serializer.object.file, uploaded_file.file) + self.assertEqual(serializer.object.created, uploaded_file.created) + self.assertEqual(serializer.object.file, uploaded_file.file) self.assertFalse(serializer.object is uploaded_file) def test_creation_failure(self): diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index af2e6c2e..fe92e0bc 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals import datetime from decimal import Decimal from django.test import TestCase @@ -64,8 +65,8 @@ class IntegrationTestFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - for obj in self.objects.all() + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} + for obj in self.objects.all() ] @unittest.skipUnless(django_filters, 'django-filters not installed') @@ -78,24 +79,24 @@ class IntegrationTestFiltering(TestCase): # Basic test with no filter. request = factory.get('/') response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) # Tests that the decimal filter works. search_decimal = Decimal('2.25') request = factory.get('/?decimal=%s' % search_decimal) response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['decimal'] == search_decimal] - self.assertEquals(response.data, expected_data) + self.assertEqual(response.data, expected_data) # Tests that the date filter works. search_date = datetime.date(2012, 9, 22) request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] == search_date] - self.assertEquals(response.data, expected_data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date] + self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): @@ -108,42 +109,43 @@ class IntegrationTestFiltering(TestCase): # Basic test with no filter. request = factory.get('/') response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) # Tests that the decimal filter set with 'lt' in the filter class works. search_decimal = Decimal('4.25') request = factory.get('/?decimal=%s' % search_decimal) response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if f['decimal'] < search_decimal] - self.assertEquals(response.data, expected_data) + self.assertEqual(response.data, expected_data) # Tests that the date filter set with 'gt' in the filter class works. search_date = datetime.date(2012, 10, 2) request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date] - self.assertEquals(response.data, expected_data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date] + self.assertEqual(response.data, expected_data) # Tests that the text filter set with 'icontains' in the filter class works. search_text = 'ff' request = factory.get('/?text=%s' % search_text) response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) expected_data = [f for f in self.data if search_text in f['text'].lower()] - self.assertEquals(response.data, expected_data) + self.assertEqual(response.data, expected_data) # Tests that multiple filters works. search_decimal = Decimal('5.25') search_date = datetime.date(2012, 10, 2) request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date and - f['decimal'] < search_decimal] - self.assertEquals(response.data, expected_data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if + datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and + f['decimal'] < search_decimal] + self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filters not installed') def test_incorrectly_configured_filter(self): @@ -165,4 +167,4 @@ class IntegrationTestFiltering(TestCase): search_integer = 10 request = factory.get('/?integer=%s' % search_integer) response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py index bc7378e1..c38bfb9f 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/genericrelations.py @@ -1,25 +1,62 @@ +from __future__ import unicode_literals +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import * + + +class Tag(models.Model): + """ + Tags have a descriptive slug, and are attached to an arbitrary object. + """ + tag = models.SlugField() + content_type = models.ForeignKey(ContentType) + object_id = models.PositiveIntegerField() + tagged_item = GenericForeignKey('content_type', 'object_id') + + def __unicode__(self): + return self.tag + + +class Bookmark(models.Model): + """ + A URL bookmark that may have multiple tags attached. + """ + url = models.URLField() + tags = GenericRelation(Tag) + + def __unicode__(self): + return 'Bookmark: %s' % self.url + + +class Note(models.Model): + """ + A textual note that may have multiple tags attached. + """ + text = models.TextField() + tags = GenericRelation(Tag) + + def __unicode__(self): + return 'Note: %s' % self.text class TestGenericRelations(TestCase): def setUp(self): - bookmark = Bookmark(url='https://www.djangoproject.com/') - bookmark.save() - django = Tag(tag_name='django') - django.save() - python = Tag(tag_name='python') - python.save() - t1 = TaggedItem(content_object=bookmark, tag=django) - t1.save() - t2 = TaggedItem(content_object=bookmark, tag=python) - t2.save() - self.bookmark = bookmark - - def test_reverse_generic_relation(self): + self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') + Tag.objects.create(tagged_item=self.bookmark, tag='django') + Tag.objects.create(tagged_item=self.bookmark, tag='python') + self.note = Note.objects.create(text='Remember the milk') + Tag.objects.create(tagged_item=self.note, tag='reminder') + + def test_generic_relation(self): + """ + Test a relationship that spans a GenericRelation field. + IE. A reverse generic relationship. + """ + class BookmarkSerializer(serializers.ModelSerializer): - tags = serializers.ManyRelatedField(source='tags') + tags = serializers.RelatedField(many=True) class Meta: model = Bookmark @@ -27,7 +64,37 @@ class TestGenericRelations(TestCase): serializer = BookmarkSerializer(self.bookmark) expected = { - 'tags': [u'django', u'python'], - 'url': u'https://www.djangoproject.com/' + 'tags': ['django', 'python'], + 'url': 'https://www.djangoproject.com/' + } + self.assertEqual(serializer.data, expected) + + def test_generic_fk(self): + """ + Test a relationship that spans a GenericForeignKey field. + IE. A forward generic relationship. + """ + + class TagSerializer(serializers.ModelSerializer): + tagged_item = serializers.RelatedField() + + class Meta: + model = Tag + exclude = ('id', 'content_type', 'object_id') + + serializer = TagSerializer(Tag.objects.all(), many=True) + expected = [ + { + 'tag': 'django', + 'tagged_item': 'Bookmark: https://www.djangoproject.com/' + }, + { + 'tag': 'python', + 'tagged_item': 'Bookmark: https://www.djangoproject.com/' + }, + { + 'tag': 'reminder', + 'tagged_item': 'Note: Remember the milk' } - self.assertEquals(serializer.data, expected) + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 4799a04b..f564890c 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,10 +1,11 @@ -import json +from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status from rest_framework.tests.utils import RequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel - +from rest_framework.compat import six +import json factory = RequestFactory() @@ -42,7 +43,7 @@ class SlugBasedInstanceView(InstanceView): class TestRootView(TestCase): def setUp(self): """ - Create 3 BasicModel intances. + Create 3 BasicModel instances. """ items = ['foo', 'bar', 'baz'] for item in items: @@ -59,9 +60,10 @@ class TestRootView(TestCase): GET requests to ListCreateAPIView should return list of objects. """ request = factory.get('/') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + with self.assertNumQueries(1): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) def test_post_root_view(self): """ @@ -70,11 +72,12 @@ class TestRootView(TestCase): content = {'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) - self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) + with self.assertNumQueries(1): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) created = self.objects.get(id=4) - self.assertEquals(created.text, 'foobar') + self.assertEqual(created.text, 'foobar') def test_put_root_view(self): """ @@ -83,25 +86,28 @@ class TestRootView(TestCase): content = {'text': 'foobar'} request = factory.put('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEquals(response.data, {"detail": "Method 'PUT' not allowed."}) + with self.assertNumQueries(0): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) def test_delete_root_view(self): """ DELETE requests to ListCreateAPIView should not be allowed """ request = factory.delete('/') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEquals(response.data, {"detail": "Method 'DELETE' not allowed."}) + with self.assertNumQueries(0): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) def test_options_root_view(self): """ OPTIONS requests to ListCreateAPIView should return metadata """ request = factory.options('/') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() expected = { 'parses': [ 'application/json', @@ -115,8 +121,8 @@ class TestRootView(TestCase): 'name': 'Root', 'description': 'Example description for OPTIONS.' } - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, expected) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, expected) def test_post_cannot_set_id(self): """ @@ -125,11 +131,12 @@ class TestRootView(TestCase): content = {'id': 999, 'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) - self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) + with self.assertNumQueries(1): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) created = self.objects.get(id=4) - self.assertEquals(created.text, 'foobar') + self.assertEqual(created.text, 'foobar') class TestInstanceView(TestCase): @@ -153,9 +160,10 @@ class TestInstanceView(TestCase): GET requests to RetrieveUpdateDestroyAPIView should return a single object. """ request = factory.get('/1') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data[0]) + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) def test_post_instance_view(self): """ @@ -164,9 +172,10 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.post('/', json.dumps(content), content_type='application/json') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEquals(response.data, {"detail": "Method 'POST' not allowed."}) + with self.assertNumQueries(0): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) def test_put_instance_view(self): """ @@ -175,11 +184,12 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk='1').render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + with self.assertNumQueries(2): + response = self.view(request, pk='1').render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) - self.assertEquals(updated.text, 'foobar') + self.assertEqual(updated.text, 'foobar') def test_patch_instance_view(self): """ @@ -189,29 +199,32 @@ class TestInstanceView(TestCase): request = factory.patch('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) - self.assertEquals(updated.text, 'foobar') + self.assertEqual(updated.text, 'foobar') def test_delete_instance_view(self): """ DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. """ request = factory.delete('/1') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_204_NO_CONTENT) - self.assertEquals(response.content, '') + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + self.assertEqual(response.content, six.b('')) ids = [obj.id for obj in self.objects.all()] - self.assertEquals(ids, [2, 3]) + self.assertEqual(ids, [2, 3]) def test_options_instance_view(self): """ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata """ request = factory.options('/') - response = self.view(request).render() + with self.assertNumQueries(0): + response = self.view(request).render() expected = { 'parses': [ 'application/json', @@ -225,8 +238,8 @@ class TestInstanceView(TestCase): 'name': 'Instance', 'description': 'Example description for OPTIONS.' } - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, expected) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, expected) def test_put_cannot_set_id(self): """ @@ -235,11 +248,12 @@ class TestInstanceView(TestCase): content = {'id': 999, 'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + with self.assertNumQueries(2): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) - self.assertEquals(updated.text, 'foobar') + self.assertEqual(updated.text, 'foobar') def test_put_to_deleted_instance(self): """ @@ -250,11 +264,12 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/1', json.dumps(content), content_type='application/json') - response = self.view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) - self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) + with self.assertNumQueries(3): + response = self.view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) - self.assertEquals(updated.text, 'foobar') + self.assertEqual(updated.text, 'foobar') def test_put_as_create_on_id_based_url(self): """ @@ -262,13 +277,14 @@ class TestInstanceView(TestCase): at the requested url if it doesn't exist. """ content = {'text': 'foobar'} - # pk fields can not be created on demand, only the database can set th pk for a new object + # pk fields can not be created on demand, only the database can set the pk for a new object request = factory.put('/5', json.dumps(content), content_type='application/json') - response = self.view(request, pk=5).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) + with self.assertNumQueries(3): + response = self.view(request, pk=5).render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) new_obj = self.objects.get(pk=5) - self.assertEquals(new_obj.text, 'foobar') + self.assertEqual(new_obj.text, 'foobar') def test_put_as_create_on_slug_based_url(self): """ @@ -278,11 +294,12 @@ class TestInstanceView(TestCase): content = {'text': 'foobar'} request = factory.put('/test_slug', json.dumps(content), content_type='application/json') - response = self.slug_based_view(request, slug='test_slug').render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) - self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'}) + with self.assertNumQueries(2): + response = self.slug_based_view(request, slug='test_slug').render() + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'}) new_obj = SlugBasedModel.objects.get(slug='test_slug') - self.assertEquals(new_obj.text, 'foobar') + self.assertEqual(new_obj.text, 'foobar') # Regression test for #285 @@ -313,12 +330,12 @@ class TestCreateModelWithAutoNowAddField(TestCase): request = factory.post('/', json.dumps(content), content_type='application/json') response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) created = self.objects.get(id=1) - self.assertEquals(created.content, 'foobar') + self.assertEqual(created.content, 'foobar') -# Test for particularly ugly reression with m2m in browseable API +# Test for particularly ugly regression with m2m in browseable API class ClassB(models.Model): name = models.CharField(max_length=255) @@ -329,7 +346,7 @@ class ClassA(models.Model): class ClassASerializer(serializers.ModelSerializer): - childs = serializers.ManyPrimaryKeyRelatedField(source='childs') + childs = serializers.PrimaryKeyRelatedField(many=True, source='childs') class Meta: model = ClassA @@ -343,9 +360,84 @@ class ExampleView(generics.ListCreateAPIView): class TestM2MBrowseableAPI(TestCase): def test_m2m_in_browseable_api(self): """ - Test for particularly ugly reression with m2m in browseable API + Test for particularly ugly regression with m2m in browseable API """ request = factory.get('/', HTTP_ACCEPT='text/html') view = ExampleView().as_view() response = view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class InclusiveFilterBackend(object): + def filter_queryset(self, request, queryset, view): + return queryset.filter(text='foo') + + +class ExclusiveFilterBackend(object): + def filter_queryset(self, request, queryset, view): + return queryset.filter(text='other') + + +class TestFilterBackendAppliedToViews(TestCase): + + def setUp(self): + """ + Create 3 BasicModel instances to filter on. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.root_view = RootView.as_view() + self.instance_view = InstanceView.as_view() + self.original_root_backend = getattr(RootView, 'filter_backend') + self.original_instance_backend = getattr(InstanceView, 'filter_backend') + + def tearDown(self): + setattr(RootView, 'filter_backend', self.original_root_backend) + setattr(InstanceView, 'filter_backend', self.original_instance_backend) + + def test_get_root_view_filters_by_name_with_filter_backend(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + setattr(RootView, 'filter_backend', InclusiveFilterBackend) + request = factory.get('/') + response = self.root_view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) + + def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): + """ + GET requests to ListCreateAPIView should return empty list when all models are filtered out. + """ + setattr(RootView, 'filter_backend', ExclusiveFilterBackend) + request = factory.get('/') + response = self.root_view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, []) + + def test_get_instance_view_filters_out_name_with_filter_backend(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. + """ + setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend) + request = factory.get('/1') + response = self.instance_view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.data, {'detail': 'Not found'}) + + def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded + """ + setattr(InstanceView, 'filter_backend', InclusiveFilterBackend) + request = factory.get('/1') + response = self.instance_view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py index 54096206..8f2e2b5a 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/htmlrenderer.py @@ -1,12 +1,15 @@ +from __future__ import unicode_literals from django.core.exceptions import PermissionDenied from django.http import Http404 from django.test import TestCase from django.template import TemplateDoesNotExist, Template import django.template.loader +from rest_framework import status from rest_framework.compat import patterns, url from rest_framework.decorators import api_view, renderer_classes from rest_framework.renderers import TemplateHTMLRenderer from rest_framework.response import Response +from rest_framework.compat import six @api_view(('GET',)) @@ -63,19 +66,19 @@ class TemplateHTMLRendererTests(TestCase): def test_simple_html_view(self): response = self.client.get('/') self.assertContains(response, "example: foobar") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html') def test_not_found_html_view(self): response = self.client.get('/not_found') - self.assertEquals(response.status_code, 404) - self.assertEquals(response.content, "404 Not Found") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.content, six.b("404 Not Found")) + self.assertEqual(response['Content-Type'], 'text/html') def test_permission_denied_html_view(self): response = self.client.get('/permission_denied') - self.assertEquals(response.status_code, 403) - self.assertEquals(response.content, "403 Forbidden") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.content, six.b("403 Forbidden")) + self.assertEqual(response['Content-Type'], 'text/html') class TemplateHTMLRendererExceptionTests(TestCase): @@ -104,12 +107,12 @@ class TemplateHTMLRendererExceptionTests(TestCase): def test_not_found_html_view_with_template(self): response = self.client.get('/not_found') - self.assertEquals(response.status_code, 404) - self.assertEquals(response.content, "404: Not found") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.content, six.b("404: Not found")) + self.assertEqual(response['Content-Type'], 'text/html') def test_permission_denied_html_view_with_template(self): response = self.client.get('/permission_denied') - self.assertEquals(response.status_code, 403) - self.assertEquals(response.content, "403: Permission denied") - self.assertEquals(response['Content-Type'], 'text/html') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual(response.content, six.b("403: Permission denied")) + self.assertEqual(response['Content-Type'], 'text/html') diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index c6a8224b..9a61f299 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals import json from django.test import TestCase from django.test.client import RequestFactory @@ -99,7 +100,7 @@ class TestBasicHyperlinkedView(TestCase): def setUp(self): """ - Create 3 BasicModel intances. + Create 3 BasicModel instances. """ items = ['foo', 'bar', 'baz'] for item in items: @@ -118,8 +119,8 @@ class TestBasicHyperlinkedView(TestCase): """ request = factory.get('/basic/') response = self.list_view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) def test_get_detail_view(self): """ @@ -127,8 +128,8 @@ class TestBasicHyperlinkedView(TestCase): """ request = factory.get('/basic/1') response = self.detail_view(request, pk=1).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data[0]) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) class TestManyToManyHyperlinkedView(TestCase): @@ -136,7 +137,7 @@ class TestManyToManyHyperlinkedView(TestCase): def setUp(self): """ - Create 3 BasicModel intances. + Create 3 BasicModel instances. """ items = ['foo', 'bar', 'baz'] anchors = [] @@ -166,8 +167,8 @@ class TestManyToManyHyperlinkedView(TestCase): """ request = factory.get('/manytomany/') response = self.list_view(request) - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) def test_get_detail_view(self): """ @@ -175,8 +176,8 @@ class TestManyToManyHyperlinkedView(TestCase): """ request = factory.get('/manytomany/1/') response = self.detail_view(request, pk=1) - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data[0]) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[0]) class TestCreateWithForeignKeys(TestCase): @@ -234,7 +235,7 @@ class TestOptionalRelationHyperlinkedView(TestCase): def setUp(self): """ - Create 1 OptionalRelationModel intances. + Create 1 OptionalRelationModel instances. """ OptionalRelationModel().save() self.objects = OptionalRelationModel.objects @@ -248,8 +249,8 @@ class TestOptionalRelationHyperlinkedView(TestCase): """ request = factory.get('/optionalrelationmodel-detail/1') response = self.detail_view(request, pk=1) - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data, self.data) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data) def test_put_detail_view(self): """ diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 93f09761..f2117538 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -1,35 +1,6 @@ +from __future__ import unicode_literals from django.db import models -from django.contrib.contenttypes.models import ContentType -from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation -# from django.contrib.auth.models import Group - - -# class CustomUser(models.Model): -# """ -# A custom user model, which uses a 'through' table for the foreign key -# """ -# username = models.CharField(max_length=255, unique=True) -# groups = models.ManyToManyField( -# to=Group, blank=True, null=True, through='UserGroupMap' -# ) - -# @models.permalink -# def get_absolute_url(self): -# return ('custom_user', (), { -# 'pk': self.id -# }) - - -# class UserGroupMap(models.Model): -# user = models.ForeignKey(to=CustomUser) -# group = models.ForeignKey(to=Group) - -# @models.permalink -# def get_absolute_url(self): -# return ('user_group_map', (), { -# 'pk': self.id -# }) def foobar(): return 'foobar' @@ -86,27 +57,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): text = models.CharField(max_length=100, default='anchor') rel = models.ManyToManyField(Anchor) -# Models to test generic relations - - -class Tag(RESTFrameworkModel): - tag_name = models.SlugField() - - -class TaggedItem(RESTFrameworkModel): - tag = models.ForeignKey(Tag, related_name='items') - content_type = models.ForeignKey(ContentType) - object_id = models.PositiveIntegerField() - content_object = GenericForeignKey('content_type', 'object_id') - - def __unicode__(self): - return self.tag.tag_name - - -class Bookmark(RESTFrameworkModel): - url = models.URLField() - tags = GenericRelation(TaggedItem) - # Model to test filtering. class FilterableItem(RESTFrameworkModel): diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py deleted file mode 100644 index f12e3b97..00000000 --- a/rest_framework/tests/modelviews.py +++ /dev/null @@ -1,90 +0,0 @@ -# from rest_framework.compat import patterns, url -# from django.forms import ModelForm -# from django.contrib.auth.models import Group, User -# from rest_framework.resources import ModelResource -# from rest_framework.views import ListOrCreateModelView, InstanceModelView -# from rest_framework.tests.models import CustomUser -# from rest_framework.tests.testcases import TestModelsTestCase - - -# class GroupResource(ModelResource): -# model = Group - - -# class UserForm(ModelForm): -# class Meta: -# model = User -# exclude = ('last_login', 'date_joined') - - -# class UserResource(ModelResource): -# model = User -# form = UserForm - - -# class CustomUserResource(ModelResource): -# model = CustomUser - -# urlpatterns = patterns('', -# url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'), -# url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)), -# url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'), -# url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)), -# url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'), -# url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)), -# ) - - -# class ModelViewTests(TestModelsTestCase): -# """Test the model views rest_framework provides""" -# urls = 'rest_framework.tests.modelviews' - -# def test_creation(self): -# """Ensure that a model object can be created""" -# self.assertEqual(0, Group.objects.count()) - -# response = self.client.post('/groups/', {'name': 'foo'}) - -# self.assertEqual(response.status_code, 201) -# self.assertEqual(1, Group.objects.count()) -# self.assertEqual('foo', Group.objects.all()[0].name) - -# def test_creation_with_m2m_relation(self): -# """Ensure that a model object with a m2m relation can be created""" -# group = Group(name='foo') -# group.save() -# self.assertEqual(0, User.objects.count()) - -# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]}) - -# self.assertEqual(response.status_code, 201) -# self.assertEqual(1, User.objects.count()) - -# user = User.objects.all()[0] -# self.assertEqual('bar', user.username) -# self.assertEqual('baz', user.password) -# self.assertEqual(1, user.groups.count()) - -# group = user.groups.all()[0] -# self.assertEqual('foo', group.name) - -# def test_creation_with_m2m_relation_through(self): -# """ -# Ensure that a model object with a m2m relation can be created where that -# relation uses a through table -# """ -# group = Group(name='foo') -# group.save() -# self.assertEqual(0, User.objects.count()) - -# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]}) - -# self.assertEqual(response.status_code, 201) -# self.assertEqual(1, CustomUser.objects.count()) - -# user = CustomUser.objects.all()[0] -# self.assertEqual('bar', user.username) -# self.assertEqual(1, user.groups.count()) - -# group = user.groups.all()[0] -# self.assertEqual('foo', group.name) diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/multitable_inheritance.py new file mode 100644 index 00000000..00c15327 --- /dev/null +++ b/rest_framework/tests/multitable_inheritance.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import RESTFrameworkModel + + +# Models +class ParentModel(RESTFrameworkModel): + name1 = models.CharField(max_length=100) + + +class ChildModel(ParentModel): + name2 = models.CharField(max_length=100) + + +class AssociatedModel(RESTFrameworkModel): + ref = models.OneToOneField(ParentModel, primary_key=True) + name = models.CharField(max_length=100) + + +# Serializers +class DerivedModelSerializer(serializers.ModelSerializer): + class Meta: + model = ChildModel + + +class AssociatedModelSerializer(serializers.ModelSerializer): + class Meta: + model = AssociatedModel + + +# Tests +class IneritedModelSerializationTests(TestCase): + + def test_multitable_inherited_model_fields_as_expected(self): + """ + Assert that the parent pointer field is not included in the fields + serialized fields + """ + child = ChildModel(name1='parent name', name2='child name') + serializer = DerivedModelSerializer(child) + self.assertEqual(set(serializer.data.keys()), + set(['name1', 'name2', 'id'])) + + def test_onetoone_primary_key_model_fields_as_expected(self): + """ + Assert that a model with a onetoone field that is the primary key is + not treated like a derived model + """ + parent = ParentModel(name1='parent name') + associate = AssociatedModel(name='hello', ref=parent) + serializer = AssociatedModelSerializer(associate) + self.assertEqual(set(serializer.data.keys()), + set(['name', 'ref'])) + + def test_data_is_valid_without_parent_ptr(self): + """ + Assert that the pointer to the parent table is not a required field + for input data + """ + data = { + 'name1': 'parent name', + 'name2': 'child name', + } + serializer = DerivedModelSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py index e06354ea..43721b84 100644 --- a/rest_framework/tests/negotiation.py +++ b/rest_framework/tests/negotiation.py @@ -1,6 +1,9 @@ +from __future__ import unicode_literals from django.test import TestCase from django.test.client import RequestFactory from rest_framework.negotiation import DefaultContentNegotiation +from rest_framework.request import Request + factory = RequestFactory() @@ -22,16 +25,16 @@ class TestAcceptedMediaType(TestCase): return self.negotiator.select_renderer(request, self.renderers) def test_client_without_accept_use_renderer(self): - request = factory.get('/') + request = Request(factory.get('/')) accepted_renderer, accepted_media_type = self.select_renderer(request) - self.assertEquals(accepted_media_type, 'application/json') + self.assertEqual(accepted_media_type, 'application/json') def test_client_underspecifies_accept_use_renderer(self): - request = factory.get('/', HTTP_ACCEPT='*/*') + request = Request(factory.get('/', HTTP_ACCEPT='*/*')) accepted_renderer, accepted_media_type = self.select_renderer(request) - self.assertEquals(accepted_media_type, 'application/json') + self.assertEqual(accepted_media_type, 'application/json') def test_client_overspecifies_accept_use_client(self): - request = factory.get('/', HTTP_ACCEPT='application/json; indent=8') + request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) accepted_renderer, accepted_media_type = self.select_renderer(request) - self.assertEquals(accepted_media_type, 'application/json; indent=8') + self.assertEqual(accepted_media_type, 'application/json; indent=8') diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 3b550877..1a2d68a6 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,5 +1,7 @@ +from __future__ import unicode_literals import datetime from decimal import Decimal +import django from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory @@ -19,21 +21,6 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 -if django_filters: - class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - paginate_by = 10 - filter_class = DecimalFilter - filter_backend = filters.DjangoFilterBackend - - class DefaultPageSizeKwargView(generics.ListAPIView): """ View for testing default paginate_by_param usage @@ -72,28 +59,32 @@ class IntegrationTestPagination(TestCase): GET requests to paginated ListCreateAPIView should return paginated results. """ request = factory.get('/') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[10:20]) - self.assertNotEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[10:20]) + self.assertNotEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[20:]) - self.assertEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[20:]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) class IntegrationTestPaginationAndFiltering(TestCase): @@ -111,41 +102,115 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - for obj in self.objects.all() + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} + for obj in self.objects.all() ] - self.view = FilterFieldsRootView.as_view() @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_paginated_filtered_root_view(self): + def test_get_django_filter_paginated_filtered_root_view(self): """ GET requests to paginated filtered ListCreateAPIView should return paginated results. The next and previous links should preserve the filtered parameters. """ + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter + filter_backend = filters.DjangoFilterBackend + + view = FilterFieldsRootView.as_view() + + EXPECTED_NUM_QUERIES = 2 + if django.VERSION < (1, 4): + # On Django 1.3 we need to use django-filter 0.5.4 + # + # The filter objects there don't expose a `.count()` method, + # which means we only make a single query *but* it's a single + # query across *all* of the queryset, instead of a COUNT and then + # a SELECT with a LIMIT. + # + # Although this is fewer queries, it's actually a regression. + EXPECTED_NUM_QUERIES = 1 + request = factory.get('/?decimal=15.20') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[10:15]) - self.assertEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[10:15]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) request = factory.get(response.data['previous']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) + + def test_get_basic_paginated_filtered_root_view(self): + """ + Same as `test_get_django_filter_paginated_filtered_root_view`, + except using a custom filter backend instead of the django-filter + backend, + """ + + class DecimalFilterBackend(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) + + class BasicFilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_backend = DecimalFilterBackend + + view = BasicFilterFieldsRootView.as_view() + + request = factory.get('/?decimal=15.20') + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) + + request = factory.get(response.data['next']) + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[10:15]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) + + request = factory.get(response.data['previous']) + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) class PassOnContextPaginationSerializer(pagination.PaginationSerializer): @@ -166,16 +231,16 @@ class UnitTestPagination(TestCase): def test_native_pagination(self): serializer = pagination.PaginationSerializer(self.first_page) - self.assertEquals(serializer.data['count'], 26) - self.assertEquals(serializer.data['next'], '?page=2') - self.assertEquals(serializer.data['previous'], None) - self.assertEquals(serializer.data['results'], self.objects[:10]) + self.assertEqual(serializer.data['count'], 26) + self.assertEqual(serializer.data['next'], '?page=2') + self.assertEqual(serializer.data['previous'], None) + self.assertEqual(serializer.data['results'], self.objects[:10]) serializer = pagination.PaginationSerializer(self.last_page) - self.assertEquals(serializer.data['count'], 26) - self.assertEquals(serializer.data['next'], None) - self.assertEquals(serializer.data['previous'], '?page=2') - self.assertEquals(serializer.data['results'], self.objects[20:]) + self.assertEqual(serializer.data['count'], 26) + self.assertEqual(serializer.data['next'], None) + self.assertEqual(serializer.data['previous'], '?page=2') + self.assertEqual(serializer.data['results'], self.objects[20:]) def test_context_available_in_result(self): """ @@ -184,7 +249,7 @@ class UnitTestPagination(TestCase): serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) serializer.data results = serializer.fields[serializer.results_field] - self.assertEquals(serializer.context, results.context) + self.assertEqual(serializer.context, results.context) class TestUnpaginated(TestCase): @@ -212,7 +277,7 @@ class TestUnpaginated(TestCase): """ request = factory.get('/') response = self.view(request) - self.assertEquals(response.data, self.data) + self.assertEqual(response.data, self.data) class TestCustomPaginateByParam(TestCase): @@ -240,7 +305,7 @@ class TestCustomPaginateByParam(TestCase): """ request = factory.get('/') response = self.view(request).render() - self.assertEquals(response.data, self.data) + self.assertEqual(response.data, self.data) def test_paginate_by_param(self): """ @@ -248,9 +313,11 @@ class TestCustomPaginateByParam(TestCase): """ request = factory.get('/?page_size=5') response = self.view(request).render() - self.assertEquals(response.data['count'], 13) - self.assertEquals(response.data['results'], self.data[:5]) + self.assertEqual(response.data['count'], 13) + self.assertEqual(response.data['results'], self.data[:5]) + +### Tests for context in pagination serializers class CustomField(serializers.Field): def to_native(self, value): @@ -262,6 +329,11 @@ class CustomField(serializers.Field): class BasicModelSerializer(serializers.Serializer): text = CustomField() + def __init__(self, *args, **kwargs): + super(BasicModelSerializer, self).__init__(*args, **kwargs) + if not 'view' in self.context: + raise RuntimeError("context isn't getting passed into serializer init") + class TestContextPassedToCustomField(TestCase): def setUp(self): @@ -277,5 +349,41 @@ class TestContextPassedToCustomField(TestCase): request = factory.get('/') response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +### Tests for custom pagination serializers + +class LinksSerializer(serializers.Serializer): + next = pagination.NextPageField(source='*') + prev = pagination.PreviousPageField(source='*') + +class CustomPaginationSerializer(pagination.BasePaginationSerializer): + links = LinksSerializer(source='*') # Takes the page object as the source + total_results = serializers.Field(source='paginator.count') + + results_field = 'objects' + + +class TestCustomPaginationSerializer(TestCase): + def setUp(self): + objects = ['john', 'paul', 'george', 'ringo'] + paginator = Paginator(objects, 2) + self.page = paginator.page(1) + + def test_custom_pagination_serializer(self): + request = RequestFactory().get('/foobar') + serializer = CustomPaginationSerializer( + instance=self.page, + context={'request': request} + ) + expected = { + 'links': { + 'next': 'http://testserver/foobar?page=2', + 'prev': None + }, + 'total_results': 4, + 'objects': ['john', 'paul'] + } + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py index 8ab8a52f..539c5b44 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/parsers.py @@ -1,139 +1,9 @@ -# """ -# .. -# >>> from rest_framework.parsers import FormParser -# >>> from django.test.client import RequestFactory -# >>> from rest_framework.views import View -# >>> from StringIO import StringIO -# >>> from urllib import urlencode -# >>> req = RequestFactory().get('/') -# >>> some_view = View() -# >>> some_view.request = req # Make as if this request had been dispatched -# -# FormParser -# ============ -# -# Data flatening -# ---------------- -# -# Here is some example data, which would eventually be sent along with a post request : -# -# >>> inpt = urlencode([ -# ... ('key1', 'bla1'), -# ... ('key2', 'blo1'), ('key2', 'blo2'), -# ... ]) -# -# Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter : -# -# >>> (data, files) = FormParser(some_view).parse(StringIO(inpt)) -# >>> data == {'key1': 'bla1', 'key2': 'blo1'} -# True -# -# However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` : -# -# >>> class MyFormParser(FormParser): -# ... -# ... def is_a_list(self, key, val_list): -# ... return len(val_list) > 1 -# -# This new parser only flattens the lists of parameters that contain a single value. -# -# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) -# >>> data == {'key1': 'bla1', 'key2': ['blo1', 'blo2']} -# True -# -# .. note:: The same functionality is available for :class:`parsers.MultiPartParser`. -# -# Submitting an empty list -# -------------------------- -# -# When submitting an empty select multiple, like this one :: -# -# <select multiple="multiple" name="key2"></select> -# -# The browsers usually strip the parameter completely. A hack to avoid this, and therefore being able to submit an empty select multiple, is to submit a value that tells the server that the list is empty :: -# -# <select multiple="multiple" name="key2"><option value="_empty"></select> -# -# :class:`parsers.FormParser` provides the server-side implementation for this hack. Considering the following posted data : -# -# >>> inpt = urlencode([ -# ... ('key1', 'blo1'), ('key1', '_empty'), -# ... ('key2', '_empty'), -# ... ]) -# -# :class:`parsers.FormParser` strips the values ``_empty`` from all the lists. -# -# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) -# >>> data == {'key1': 'blo1'} -# True -# -# Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it. -# -# >>> class MyFormParser(FormParser): -# ... -# ... def is_a_list(self, key, val_list): -# ... return key == 'key2' -# ... -# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) -# >>> data == {'key1': 'blo1', 'key2': []} -# True -# -# Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`. -# """ -# import httplib, mimetypes -# from tempfile import TemporaryFile -# from django.test import TestCase -# from django.test.client import RequestFactory -# from rest_framework.parsers import MultiPartParser -# from rest_framework.views import View -# from StringIO import StringIO -# -# def encode_multipart_formdata(fields, files): -# """For testing multipart parser. -# fields is a sequence of (name, value) elements for regular form fields. -# files is a sequence of (name, filename, value) elements for data to be uploaded as files -# Return (content_type, body).""" -# BOUNDARY = '----------ThIs_Is_tHe_bouNdaRY_$' -# CRLF = '\r\n' -# L = [] -# for (key, value) in fields: -# L.append('--' + BOUNDARY) -# L.append('Content-Disposition: form-data; name="%s"' % key) -# L.append('') -# L.append(value) -# for (key, filename, value) in files: -# L.append('--' + BOUNDARY) -# L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename)) -# L.append('Content-Type: %s' % get_content_type(filename)) -# L.append('') -# L.append(value) -# L.append('--' + BOUNDARY + '--') -# L.append('') -# body = CRLF.join(L) -# content_type = 'multipart/form-data; boundary=%s' % BOUNDARY -# return content_type, body -# -# def get_content_type(filename): -# return mimetypes.guess_type(filename)[0] or 'application/octet-stream' -# -#class TestMultiPartParser(TestCase): -# def setUp(self): -# self.req = RequestFactory() -# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')], -# [('file1', 'pic.jpg', 'blablabla'), ('file1', 't.txt', 'blobloblo')]) -# -# def test_multipartparser(self): -# """Ensure that MultiPartParser can parse multipart/form-data that contains a mix of several files and parameters.""" -# post_req = RequestFactory().post('/', self.body, content_type=self.content_type) -# view = View() -# view.request = post_req -# (data, files) = MultiPartParser(view).parse(StringIO(self.body)) -# self.assertEqual(data['key1'], 'val1') -# self.assertEqual(files['file1'].read(), 'blablabla') - -from StringIO import StringIO +from __future__ import unicode_literals +from rest_framework.compat import StringIO from django import forms from django.test import TestCase +from django.utils import unittest +from rest_framework.compat import etree from rest_framework.parsers import FormParser from rest_framework.parsers import XMLParser import datetime @@ -201,11 +71,13 @@ class TestXMLParser(TestCase): ] } + @unittest.skipUnless(etree, 'defusedxml not installed') def test_parse(self): parser = XMLParser() data = parser.parse(self._input) self.assertEqual(data, self._data) + @unittest.skipUnless(etree, 'defusedxml not installed') def test_complex_data_parse(self): parser = XMLParser() data = parser.parse(self._complex_data_input) diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/permissions.py new file mode 100644 index 00000000..b3993be5 --- /dev/null +++ b/rest_framework/tests/permissions.py @@ -0,0 +1,153 @@ +from __future__ import unicode_literals +from django.contrib.auth.models import User, Permission +from django.db import models +from django.test import TestCase +from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING +from rest_framework.tests.utils import RequestFactory +import base64 +import json + +factory = RequestFactory() + + +class BasicModel(models.Model): + text = models.CharField(max_length=100) + + +class RootView(generics.ListCreateAPIView): + model = BasicModel + authentication_classes = [authentication.BasicAuthentication] + permission_classes = [permissions.DjangoModelPermissions] + + +class InstanceView(generics.RetrieveUpdateDestroyAPIView): + model = BasicModel + authentication_classes = [authentication.BasicAuthentication] + permission_classes = [permissions.DjangoModelPermissions] + +root_view = RootView.as_view() +instance_view = InstanceView.as_view() + + +def basic_auth_header(username, password): + credentials = ('%s:%s' % (username, password)) + base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) + return 'Basic %s' % base64_credentials + + +class ModelPermissionsIntegrationTests(TestCase): + def setUp(self): + User.objects.create_user('disallowed', 'disallowed@example.com', 'password') + user = User.objects.create_user('permitted', 'permitted@example.com', 'password') + user.user_permissions = [ + Permission.objects.get(codename='add_basicmodel'), + Permission.objects.get(codename='change_basicmodel'), + Permission.objects.get(codename='delete_basicmodel') + ] + user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') + user.user_permissions = [ + Permission.objects.get(codename='change_basicmodel'), + ] + + self.permitted_credentials = basic_auth_header('permitted', 'password') + self.disallowed_credentials = basic_auth_header('disallowed', 'password') + self.updateonly_credentials = basic_auth_header('updateonly', 'password') + + BasicModel(text='foo').save() + + def test_has_create_permissions(self): + request = factory.post('/', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = root_view(request, pk=1) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def test_has_put_permissions(self): + request = factory.put('/1', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_has_delete_permissions(self): + request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk=1) + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + def test_does_not_have_create_permissions(self): + request = factory.post('/', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = root_view(request, pk=1) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_does_not_have_put_permissions(self): + request = factory.put('/1', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_does_not_have_delete_permissions(self): + request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk=1) + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_has_put_as_create_permissions(self): + # User only has update permissions - should be able to update an entity. + request = factory.put('/1', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # But if PUTing to a new entity, permission should be denied. + request = factory.put('/2', json.dumps({'text': 'foobar'}), + content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = instance_view(request, pk='2') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + +class OwnerModel(models.Model): + text = models.CharField(max_length=100) + owner = models.ForeignKey(User) + + +class IsOwnerPermission(permissions.BasePermission): + def has_object_permission(self, request, view, obj): + return request.user == obj.owner + + +class OwnerInstanceView(generics.RetrieveUpdateDestroyAPIView): + model = OwnerModel + authentication_classes = [authentication.BasicAuthentication] + permission_classes = [IsOwnerPermission] + + +owner_instance_view = OwnerInstanceView.as_view() + + +class ObjectPermissionsIntegrationTests(TestCase): + """ + Integration tests for the object level permissions API. + """ + + def setUp(self): + User.objects.create_user('not_owner', 'not_owner@example.com', 'password') + user = User.objects.create_user('owner', 'owner@example.com', 'password') + + self.not_owner_credentials = basic_auth_header('not_owner', 'password') + self.owner_credentials = basic_auth_header('owner', 'password') + + OwnerModel(text='foo', owner=user).save() + + def test_owner_has_delete_permissions(self): + request = factory.delete('/1', HTTP_AUTHORIZATION=self.owner_credentials) + response = owner_instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + + def test_non_owner_does_not_have_delete_permissions(self): + request = factory.delete('/1', HTTP_AUTHORIZATION=self.not_owner_credentials) + response = owner_instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py index 91daea8a..cbf93c65 100644 --- a/rest_framework/tests/relations.py +++ b/rest_framework/tests/relations.py @@ -1,7 +1,7 @@ """ General tests for relational fields. """ - +from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import serializers @@ -31,3 +31,17 @@ class FieldTests(TestCase): field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') self.assertRaises(serializers.ValidationError, field.from_native, '') self.assertRaises(serializers.ValidationError, field.from_native, []) + + +class TestManyRelateMixin(TestCase): + def test_missing_many_to_many_related_field(self): + ''' + Regression test for #632 + + https://github.com/tomchristie/django-rest-framework/pull/632 + ''' + field = serializers.RelatedField(many=True, read_only=False) + + into = {} + field.field_from_native({}, None, 'field_name', into) + self.assertEqual(into['field_name'], []) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index 57913670..b5702a48 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -1,7 +1,16 @@ +from __future__ import unicode_literals from django.test import TestCase +from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url -from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +from rest_framework.tests.models import ( + ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, + NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +) + +factory = RequestFactory() +request = factory.get('/') # Just to ensure we have a request in the serializer context + def dummy_view(request, pk): pass @@ -16,8 +25,9 @@ urlpatterns = patterns('', url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), ) + class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') + sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail') class Meta: model = ManyToManyTarget @@ -29,7 +39,7 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): - sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail') + sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail') class Meta: model = ForeignKeyTarget @@ -70,98 +80,98 @@ class HyperlinkedManyToManyTests(TestCase): def test_many_to_many_retrieve(self): queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, - {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, - {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, + {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, + {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_retrieve(self): queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} + {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_many_to_many_update(self): - data = {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} instance = ManyToManySource.objects.get(pk=1) - serializer = ManyToManySourceSerializer(instance, data=data) + serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) serializer.save() - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, - {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, - {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} + {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, + {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, + {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_update(self): - data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']} + data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} instance = ManyToManyTarget.objects.get(pk=1) - serializer = ManyToManyTargetSerializer(instance, data=data) + serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) serializer.save() - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) # Ensure target 1 is updated, and everything else is as expected queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}, - {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} + {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, + {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_many_to_many_create(self): - data = {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} - serializer = ManyToManySourceSerializer(data=data) + data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} + serializer = ManyToManySourceSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is added, and everything else is as expected queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, - {'url': '/manytomanysource/2/', 'name': u'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, - {'url': '/manytomanysource/3/', 'name': u'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}, - {'url': '/manytomanysource/4/', 'name': u'source-4', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/3/']} + {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, + {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, + {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, + {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_create(self): - data = {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} - serializer = ManyToManyTargetSerializer(data=data) + data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} + serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'target-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') # Ensure target 4 is added, and everything else is as expected queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/2/', 'name': u'target-2', 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, - {'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}, - {'url': '/manytomanytarget/4/', 'name': u'target-4', 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} + {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, + {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) class HyperlinkedForeignKeyTests(TestCase): @@ -178,111 +188,118 @@ class HyperlinkedForeignKeyTests(TestCase): def test_foreign_key_retrieve(self): queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} + {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, - {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, + {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, + {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update(self): - data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'} + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}, - {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} + {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, + {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_incorrect_type(self): + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']}) def test_reverse_foreign_key_update(self): - data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} + data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} instance = ForeignKeyTarget.objects.get(pk=2) - serializer = ForeignKeyTargetSerializer(instance, data=data) + serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) # We shouldn't have saved anything to the db yet since save # hasn't been called. queryset = ForeignKeyTarget.objects.all() - new_serializer = ForeignKeyTargetSerializer(queryset) + new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, - {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, - ] - self.assertEquals(new_serializer.data, expected) + {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, + {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, + ] + self.assertEqual(new_serializer.data, expected) serializer.save() - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) # Ensure target 2 is update, and everything else is as expected queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, - {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, + {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, + {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_create(self): - data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'} - serializer = ForeignKeySourceSerializer(data=data) + data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} + serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, - {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}, + {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_create(self): - data = {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']} - serializer = ForeignKeyTargetSerializer(data=data) + data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} + serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'target-3') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3') # Ensure target 4 is added, and everything else is as expected queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, - {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, - {'url': '/foreignkeytarget/3/', 'name': u'target-3', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, + {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, + {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, + {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_invalid_null(self): - data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None} + data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} instance = ForeignKeySource.objects.get(pk=1) - serializer = ForeignKeySourceSerializer(instance, data=data) + serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + self.assertEqual(serializer.errors, {'target': ['This field is required.']}) class HyperlinkedNullableForeignKeyTests(TestCase): @@ -299,118 +316,118 @@ class HyperlinkedNullableForeignKeyTests(TestCase): def test_foreign_key_retrieve_with_null(self): queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_create_with_valid_null(self): - data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data) + data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, - {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_create_with_valid_emptystring(self): """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''} - expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} - serializer = NullableForeignKeySourceSerializer(data=data) + data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} + expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, expected_data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, expected_data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, - {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_valid_null(self): - data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} + data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data) + serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_valid_emptystring(self): """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''} - expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} + data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} + expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) - serializer = NullableForeignKeySourceSerializer(instance, data=data) + serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, expected_data) + self.assertEqual(serializer.data, expected_data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, - {'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, - {'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, + {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, + {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) # reverse foreign keys MUST be read_only # In the general case they do not provide .remove() or .clear() # and cannot be arbitrarily set. # def test_reverse_foreign_key_update(self): - # data = {'id': 1, 'name': u'target-1', 'sources': [1]} + # data = {'id': 1, 'name': 'target-1', 'sources': [1]} # instance = ForeignKeyTarget.objects.get(pk=1) # serializer = ForeignKeyTargetSerializer(instance, data=data) # self.assertTrue(serializer.is_valid()) - # self.assertEquals(serializer.data, data) + # self.assertEqual(serializer.data, data) # serializer.save() # # Ensure target 1 is updated, and everything else is as expected # queryset = ForeignKeyTarget.objects.all() - # serializer = ForeignKeyTargetSerializer(queryset) + # serializer = ForeignKeyTargetSerializer(queryset, many=True) # expected = [ - # {'id': 1, 'name': u'target-1', 'sources': [1]}, - # {'id': 2, 'name': u'target-2', 'sources': []}, + # {'id': 1, 'name': 'target-1', 'sources': [1]}, + # {'id': 2, 'name': 'target-2', 'sources': []}, # ] - # self.assertEquals(serializer.data, expected) + # self.assertEqual(serializer.data, expected) class HyperlinkedNullableOneToOneTests(TestCase): @@ -426,9 +443,9 @@ class HyperlinkedNullableOneToOneTests(TestCase): def test_reverse_foreign_key_retrieve_with_null(self): queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset) + serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) expected = [ - {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'}, - {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None}, + {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, + {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 0e129fae..a125ba65 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from django.test import TestCase from rest_framework import serializers from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource @@ -15,7 +16,7 @@ class FlatForeignKeySourceSerializer(serializers.ModelSerializer): class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = FlatForeignKeySourceSerializer() + sources = FlatForeignKeySourceSerializer(many=True) class Meta: model = ForeignKeyTarget @@ -51,27 +52,27 @@ class ReverseForeignKeyTests(TestCase): def test_foreign_key_retrieve(self): queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, - {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, - {'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}}, + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1}, + {'id': 1, 'name': 'target-1', 'sources': [ + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': 1}, ]}, - {'id': 2, 'name': u'target-2', 'sources': [ + {'id': 2, 'name': 'target-2', 'sources': [ ]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) class NestedNullableForeignKeyTests(TestCase): @@ -86,13 +87,13 @@ class NestedNullableForeignKeyTests(TestCase): def test_foreign_key_retrieve_with_null(self): queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, - {'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, - {'id': 3, 'name': u'source-3', 'target': None}, + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 3, 'name': 'source-3', 'target': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) class NestedNullableOneToOneTests(TestCase): @@ -106,9 +107,9 @@ class NestedNullableOneToOneTests(TestCase): def test_reverse_foreign_key_retrieve_with_null(self): queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset) + serializer = NullableOneToOneTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}}, - {'id': 2, 'name': u'target-2', 'nullable_source': None}, + {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, + {'id': 2, 'name': 'target-2', 'nullable_source': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index 54835860..f08e1808 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -1,10 +1,12 @@ +from __future__ import unicode_literals from django.test import TestCase from rest_framework import serializers from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +from rest_framework.compat import six class ManyToManyTargetSerializer(serializers.ModelSerializer): - sources = serializers.ManyPrimaryKeyRelatedField() + sources = serializers.PrimaryKeyRelatedField(many=True) class Meta: model = ManyToManyTarget @@ -16,7 +18,7 @@ class ManyToManySourceSerializer(serializers.ModelSerializer): class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = serializers.ManyPrimaryKeyRelatedField() + sources = serializers.PrimaryKeyRelatedField(many=True) class Meta: model = ForeignKeyTarget @@ -54,97 +56,97 @@ class PKManyToManyTests(TestCase): def test_many_to_many_retrieve(self): queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'targets': [1]}, - {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} + {'id': 1, 'name': 'source-1', 'targets': [1]}, + {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_retrieve(self): queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': u'target-3', 'sources': [3]} + {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': 'target-3', 'sources': [3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_many_to_many_update(self): - data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]} + data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} instance = ManyToManySource.objects.get(pk=1) serializer = ManyToManySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) serializer.save() - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}, - {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} + {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, + {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_update(self): - data = {'id': 1, 'name': u'target-1', 'sources': [1]} + data = {'id': 1, 'name': 'target-1', 'sources': [1]} instance = ManyToManyTarget.objects.get(pk=1) serializer = ManyToManyTargetSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) serializer.save() - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) # Ensure target 1 is updated, and everything else is as expected queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1]}, - {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': u'target-3', 'sources': [3]} + {'id': 1, 'name': 'target-1', 'sources': [1]}, + {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': 'target-3', 'sources': [3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_many_to_many_create(self): - data = {'id': 4, 'name': u'source-4', 'targets': [1, 3]} + data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} serializer = ManyToManySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is added, and everything else is as expected queryset = ManyToManySource.objects.all() - serializer = ManyToManySourceSerializer(queryset) + serializer = ManyToManySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'targets': [1]}, - {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, - {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}, - {'id': 4, 'name': u'source-4', 'targets': [1, 3]}, + {'id': 1, 'name': 'source-1', 'targets': [1]}, + {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, + {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, + {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_many_to_many_create(self): - data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]} + data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} serializer = ManyToManyTargetSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'target-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') # Ensure target 4 is added, and everything else is as expected queryset = ManyToManyTarget.objects.all() - serializer = ManyToManyTargetSerializer(queryset) + serializer = ManyToManyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, - {'id': 3, 'name': u'target-3', 'sources': [3]}, - {'id': 4, 'name': u'target-4', 'sources': [1, 3]} + {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, + {'id': 3, 'name': 'target-3', 'sources': [3]}, + {'id': 4, 'name': 'target-4', 'sources': [1, 3]} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) class PKForeignKeyTests(TestCase): @@ -159,111 +161,118 @@ class PKForeignKeyTests(TestCase): def test_foreign_key_retrieve(self): queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1} + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': 1} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': u'target-2', 'sources': []}, + {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': 'target-2', 'sources': []}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update(self): - data = {'id': 1, 'name': u'source-1', 'target': 2} + data = {'id': 1, 'name': 'source-1', 'target': 2} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 2}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1} + {'id': 1, 'name': 'source-1', 'target': 2}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': 1} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_incorrect_type(self): + data = {'id': 1, 'name': 'source-1', 'target': 'foo'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}) def test_reverse_foreign_key_update(self): - data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]} + data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} instance = ForeignKeyTarget.objects.get(pk=2) serializer = ForeignKeyTargetSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) # We shouldn't have saved anything to the db yet since save # hasn't been called. queryset = ForeignKeyTarget.objects.all() - new_serializer = ForeignKeyTargetSerializer(queryset) + new_serializer = ForeignKeyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, - {'id': 2, 'name': u'target-2', 'sources': []}, - ] - self.assertEquals(new_serializer.data, expected) + {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, + {'id': 2, 'name': 'target-2', 'sources': []}, + ] + self.assertEqual(new_serializer.data, expected) serializer.save() - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) # Ensure target 2 is update, and everything else is as expected queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [2]}, - {'id': 2, 'name': u'target-2', 'sources': [1, 3]}, + {'id': 1, 'name': 'target-1', 'sources': [2]}, + {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_create(self): - data = {'id': 4, 'name': u'source-4', 'target': 2} + data = {'id': 4, 'name': 'source-4', 'target': 2} serializer = ForeignKeySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is added, and everything else is as expected queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset) + serializer = ForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': 1}, - {'id': 4, 'name': u'source-4', 'target': 2}, + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': 1}, + {'id': 4, 'name': 'source-4', 'target': 2}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_create(self): - data = {'id': 3, 'name': u'target-3', 'sources': [1, 3]} + data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} serializer = ForeignKeyTargetSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'target-3') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3') # Ensure target 3 is added, and everything else is as expected queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset) + serializer = ForeignKeyTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'sources': [2]}, - {'id': 2, 'name': u'target-2', 'sources': []}, - {'id': 3, 'name': u'target-3', 'sources': [1, 3]}, + {'id': 1, 'name': 'target-1', 'sources': [2]}, + {'id': 2, 'name': 'target-2', 'sources': []}, + {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_invalid_null(self): - data = {'id': 1, 'name': u'source-1', 'target': None} + data = {'id': 1, 'name': 'source-1', 'target': None} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + self.assertEqual(serializer.errors, {'target': ['This field is required.']}) class PKNullableForeignKeyTests(TestCase): @@ -278,118 +287,118 @@ class PKNullableForeignKeyTests(TestCase): def test_foreign_key_retrieve_with_null(self): queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': None}, + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_create_with_valid_null(self): - data = {'id': 4, 'name': u'source-4', 'target': None} + data = {'id': 4, 'name': 'source-4', 'target': None} serializer = NullableForeignKeySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': None}, - {'id': 4, 'name': u'source-4', 'target': None} + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 4, 'name': 'source-4', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_create_with_valid_emptystring(self): """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'id': 4, 'name': u'source-4', 'target': ''} - expected_data = {'id': 4, 'name': u'source-4', 'target': None} + data = {'id': 4, 'name': 'source-4', 'target': ''} + expected_data = {'id': 4, 'name': 'source-4', 'target': None} serializer = NullableForeignKeySourceSerializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() - self.assertEquals(serializer.data, expected_data) - self.assertEqual(obj.name, u'source-4') + self.assertEqual(serializer.data, expected_data) + self.assertEqual(obj.name, 'source-4') # Ensure source 4 is created, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': 1}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': None}, - {'id': 4, 'name': u'source-4', 'target': None} + {'id': 1, 'name': 'source-1', 'target': 1}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 4, 'name': 'source-4', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_valid_null(self): - data = {'id': 1, 'name': u'source-1', 'target': None} + data = {'id': 1, 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, data) + self.assertEqual(serializer.data, data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': None}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': None} + {'id': 1, 'name': 'source-1', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_foreign_key_update_with_valid_emptystring(self): """ The emptystring should be interpreted as null in the context of relationships. """ - data = {'id': 1, 'name': u'source-1', 'target': ''} - expected_data = {'id': 1, 'name': u'source-1', 'target': None} + data = {'id': 1, 'name': 'source-1', 'target': ''} + expected_data = {'id': 1, 'name': 'source-1', 'target': None} instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.data, expected_data) + self.assertEqual(serializer.data, expected_data) serializer.save() # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset) + serializer = NullableForeignKeySourceSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'source-1', 'target': None}, - {'id': 2, 'name': u'source-2', 'target': 1}, - {'id': 3, 'name': u'source-3', 'target': None} + {'id': 1, 'name': 'source-1', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': 1}, + {'id': 3, 'name': 'source-3', 'target': None} ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) # reverse foreign keys MUST be read_only # In the general case they do not provide .remove() or .clear() # and cannot be arbitrarily set. # def test_reverse_foreign_key_update(self): - # data = {'id': 1, 'name': u'target-1', 'sources': [1]} + # data = {'id': 1, 'name': 'target-1', 'sources': [1]} # instance = ForeignKeyTarget.objects.get(pk=1) # serializer = ForeignKeyTargetSerializer(instance, data=data) # self.assertTrue(serializer.is_valid()) - # self.assertEquals(serializer.data, data) + # self.assertEqual(serializer.data, data) # serializer.save() # # Ensure target 1 is updated, and everything else is as expected # queryset = ForeignKeyTarget.objects.all() - # serializer = ForeignKeyTargetSerializer(queryset) + # serializer = ForeignKeyTargetSerializer(queryset, many=True) # expected = [ - # {'id': 1, 'name': u'target-1', 'sources': [1]}, - # {'id': 2, 'name': u'target-2', 'sources': []}, + # {'id': 1, 'name': 'target-1', 'sources': [1]}, + # {'id': 2, 'name': 'target-2', 'sources': []}, # ] - # self.assertEquals(serializer.data, expected) + # self.assertEqual(serializer.data, expected) class PKNullableOneToOneTests(TestCase): @@ -398,14 +407,14 @@ class PKNullableOneToOneTests(TestCase): target.save() new_target = OneToOneTarget(name='target-2') new_target.save() - source = NullableOneToOneSource(name='source-1', target=target) + source = NullableOneToOneSource(name='source-1', target=new_target) source.save() def test_reverse_foreign_key_retrieve_with_null(self): queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset) + serializer = NullableOneToOneTargetSerializer(queryset, many=True) expected = [ - {'id': 1, 'name': u'target-1', 'nullable_source': 1}, - {'id': 2, 'name': u'target-2', 'nullable_source': None}, + {'id': 1, 'name': 'target-1', 'nullable_source': None}, + {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ] - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py new file mode 100644 index 00000000..435c821c --- /dev/null +++ b/rest_framework/tests/relations_slug.py @@ -0,0 +1,257 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): + sources = serializers.SlugRelatedField(many=True, slug_field='name') + + class Meta: + model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): + target = serializers.SlugRelatedField(slug_field='name') + + class Meta: + model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): + target = serializers.SlugRelatedField(slug_field='name', required=False) + + class Meta: + model = NullableForeignKeySource + + +# TODO: M2M Tests, FKTests (Non-nullable), One2One +class SlugForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + new_target = ForeignKeyTarget(name='target-2') + new_target.save() + for idx in range(1, 4): + source = ForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve(self): + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': 'target-1'} + ] + self.assertEqual(serializer.data, expected) + + def test_reverse_foreign_key_retrieve(self): + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, + {'id': 2, 'name': 'target-2', 'sources': []}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update(self): + data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-2'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': 'target-1'} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_incorrect_type(self): + data = {'id': 1, 'name': 'source-1', 'target': 123} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) + + def test_reverse_foreign_key_update(self): + data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} + instance = ForeignKeyTarget.objects.get(pk=2) + serializer = ForeignKeyTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + # We shouldn't have saved anything to the db yet since save + # hasn't been called. + queryset = ForeignKeyTarget.objects.all() + new_serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, + {'id': 2, 'name': 'target-2', 'sources': []}, + ] + self.assertEqual(new_serializer.data, expected) + + serializer.save() + self.assertEqual(serializer.data, data) + + # Ensure target 2 is update, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, + {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_create(self): + data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} + serializer = ForeignKeySourceSerializer(data=data) + serializer.is_valid() + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure source 4 is added, and everything else is as expected + queryset = ForeignKeySource.objects.all() + serializer = ForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': 'target-1'}, + {'id': 4, 'name': 'source-4', 'target': 'target-2'}, + ] + self.assertEqual(serializer.data, expected) + + def test_reverse_foreign_key_create(self): + data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} + serializer = ForeignKeyTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3') + + # Ensure target 3 is added, and everything else is as expected + queryset = ForeignKeyTarget.objects.all() + serializer = ForeignKeyTargetSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, + {'id': 2, 'name': 'target-2', 'sources': []}, + {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_with_invalid_null(self): + data = {'id': 1, 'name': 'source-1', 'target': None} + instance = ForeignKeySource.objects.get(pk=1) + serializer = ForeignKeySourceSerializer(instance, data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + + +class SlugNullableForeignKeyTests(TestCase): + def setUp(self): + target = ForeignKeyTarget(name='target-1') + target.save() + for idx in range(1, 4): + if idx == 3: + target = None + source = NullableForeignKeySource(name='source-%d' % idx, target=target) + source.save() + + def test_foreign_key_retrieve_with_null(self): + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None}, + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_create_with_valid_null(self): + data = {'id': 4, 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 4, 'name': 'source-4', 'target': None} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_create_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'id': 4, 'name': 'source-4', 'target': ''} + expected_data = {'id': 4, 'name': 'source-4', 'target': None} + serializer = NullableForeignKeySourceSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, expected_data) + self.assertEqual(obj.name, 'source-4') + + # Ensure source 4 is created, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': 'target-1'}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 4, 'name': 'source-4', 'target': None} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_with_valid_null(self): + data = {'id': 1, 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None} + ] + self.assertEqual(serializer.data, expected) + + def test_foreign_key_update_with_valid_emptystring(self): + """ + The emptystring should be interpreted as null in the context + of relationships. + """ + data = {'id': 1, 'name': 'source-1', 'target': ''} + expected_data = {'id': 1, 'name': 'source-1', 'target': None} + instance = NullableForeignKeySource.objects.get(pk=1) + serializer = NullableForeignKeySourceSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, expected_data) + serializer.save() + + # Ensure source 1 is updated, and everything else is as expected + queryset = NullableForeignKeySource.objects.all() + serializer = NullableForeignKeySourceSerializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': 'target-1'}, + {'id': 3, 'name': 'source-3', 'target': None} + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py index c1b4e624..40bac9cb 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/renderers.py @@ -1,29 +1,28 @@ -import pickle -import re - +from decimal import Decimal from django.core.cache import cache from django.test import TestCase from django.test.client import RequestFactory - +from django.utils import unittest from rest_framework import status, permissions -from rest_framework.compat import yaml, patterns, url, include +from rest_framework.compat import yaml, etree, patterns, url, include from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings - -from StringIO import StringIO +from rest_framework.compat import StringIO +from rest_framework.compat import six import datetime -from decimal import Decimal +import pickle +import re DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' -RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x -RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x +RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii') +RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') expected_results = [ @@ -35,7 +34,7 @@ class BasicRendererTests(TestCase): def test_expected_results(self): for value, renderer_cls, expected in expected_results: output = renderer_cls().render(value) - self.assertEquals(output, expected) + self.assertEqual(output, expected) class RendererA(BaseRenderer): @@ -94,7 +93,7 @@ urlpatterns = patterns('', class POSTDeniedPermission(permissions.BasePermission): - def has_permission(self, request, view, obj=None): + def has_permission(self, request, view): return request.method != 'POST' @@ -111,6 +110,9 @@ class POSTDeniedView(APIView): def put(self, request): return Response() + def patch(self, request): + return Response() + class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self): @@ -119,6 +121,7 @@ class DocumentingRendererTests(TestCase): response = view(request).render() self.assertNotContains(response, '>POST<') self.assertContains(response, '>PUT<') + self.assertContains(response, '>PATCH<') class RendererEndToEndTests(TestCase): @@ -131,39 +134,39 @@ class RendererEndToEndTests(TestCase): def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" resp = self.client.get('/') - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_head_method_serializes_no_content(self): """No response must be included in HEAD requests.""" resp = self.client.head('/') - self.assertEquals(resp.status_code, DUMMYSTATUS) - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, '') + self.assertEqual(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, six.b('')) def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_default_case(self): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_non_default_case(self): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_accept_query(self): """The '_accept' query string should behave in the same way as the Accept header.""" @@ -172,14 +175,14 @@ class RendererEndToEndTests(TestCase): RendererB.media_type ) resp = self.client.get('/' + param) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_unsatisfiable_accept_header_on_request_returns_406_status(self): """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" resp = self.client.get('/', HTTP_ACCEPT='foo/bar') - self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) + self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching @@ -189,17 +192,17 @@ class RendererEndToEndTests(TestCase): RendererB.format ) resp = self.client.get('/' + param) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_format_kwargs(self): """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/something.formatb') - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): """If both a 'format' query and a matching Accept header specified, @@ -210,9 +213,9 @@ class RendererEndToEndTests(TestCase): ) resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) _flat_repr = '{"foo": ["bar", "baz"]}' @@ -240,7 +243,7 @@ class JSONRendererTests(TestCase): renderer = JSONRenderer() content = renderer.render(obj, 'application/json') # Fix failing test case which depends on version of JSON library. - self.assertEquals(content, _flat_repr) + self.assertEqual(content, _flat_repr) def test_with_content_type_args(self): """ @@ -249,7 +252,7 @@ class JSONRendererTests(TestCase): obj = {'foo': ['bar', 'baz']} renderer = JSONRenderer() content = renderer.render(obj, 'application/json; indent=2') - self.assertEquals(strip_trailing_whitespace(content), _indented_repr) + self.assertEqual(strip_trailing_whitespace(content), _indented_repr) class JSONPRendererTests(TestCase): @@ -265,9 +268,10 @@ class JSONPRendererTests(TestCase): """ resp = self.client.get('/jsonp/jsonrenderer', HTTP_ACCEPT='application/javascript') - self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/javascript') - self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp.content, + ('callback(%s);' % _flat_repr).encode('ascii')) def test_without_callback_without_json_renderer(self): """ @@ -275,9 +279,10 @@ class JSONPRendererTests(TestCase): """ resp = self.client.get('/jsonp/nojsonrenderer', HTTP_ACCEPT='application/javascript') - self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/javascript') - self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp.content, + ('callback(%s);' % _flat_repr).encode('ascii')) def test_with_callback(self): """ @@ -286,9 +291,10 @@ class JSONPRendererTests(TestCase): callback_func = 'myjsonpcallback' resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, HTTP_ACCEPT='application/javascript') - self.assertEquals(resp.status_code, 200) - self.assertEquals(resp['Content-Type'], 'application/javascript') - self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr)) + self.assertEqual(resp.status_code, status.HTTP_200_OK) + self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp.content, + ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) if yaml: @@ -306,7 +312,7 @@ if yaml: obj = {'foo': ['bar', 'baz']} renderer = YAMLRenderer() content = renderer.render(obj, 'application/yaml') - self.assertEquals(content, _yaml_repr) + self.assertEqual(content, _yaml_repr) def test_render_and_parse(self): """ @@ -320,7 +326,7 @@ if yaml: content = renderer.render(obj, 'application/yaml') data = parser.parse(StringIO(content)) - self.assertEquals(obj, data) + self.assertEqual(obj, data) class XMLRendererTestCase(TestCase): @@ -402,6 +408,7 @@ class XMLRendererTestCase(TestCase): self.assertXMLContains(content, '<sub_name>first</sub_name>') self.assertXMLContains(content, '<sub_name>second</sub_name>') + @unittest.skipUnless(etree, 'defusedxml not installed') def test_render_and_parse_complex_data(self): """ Test XML rendering. diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 4b032405..97e5af20 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -1,7 +1,7 @@ """ Tests for content parsing, and form-overloaded content parsing. """ -import json +from __future__ import unicode_literals from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware @@ -20,6 +20,8 @@ from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView +from rest_framework.compat import six +import json factory = RequestFactory() @@ -56,21 +58,29 @@ class TestMethodOverloading(TestCase): request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'})) self.assertEqual(request.method, 'DELETE') + def test_x_http_method_override_header(self): + """ + POST requests can also be overloaded to another method by setting + the X-HTTP-Method-Override header. + """ + request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) + self.assertEqual(request.method, 'DELETE') + class TestContentParsing(TestCase): def test_standard_behaviour_determines_no_content_GET(self): """ - Ensure request.DATA returns None for GET request with no content. + Ensure request.DATA returns empty QueryDict for GET request. """ request = Request(factory.get('/')) - self.assertEqual(request.DATA, None) + self.assertEqual(request.DATA, {}) def test_standard_behaviour_determines_no_content_HEAD(self): """ - Ensure request.DATA returns None for HEAD request. + Ensure request.DATA returns empty QueryDict for HEAD request. """ request = Request(factory.head('/')) - self.assertEqual(request.DATA, None) + self.assertEqual(request.DATA, {}) def test_request_DATA_with_form_content(self): """ @@ -79,14 +89,14 @@ class TestContentParsing(TestCase): data = {'qwerty': 'uiop'} request = Request(factory.post('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(request.DATA.items(), data.items()) + self.assertEqual(list(request.DATA.items()), list(data.items())) def test_request_DATA_with_text_content(self): """ Ensure request.DATA returns content for POST request with non-form content. """ - content = 'qwerty' + content = six.b('qwerty') content_type = 'text/plain' request = Request(factory.post('/', content, content_type=content_type)) request.parsers = (PlainTextParser(),) @@ -99,7 +109,7 @@ class TestContentParsing(TestCase): data = {'qwerty': 'uiop'} request = Request(factory.post('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(request.POST.items(), data.items()) + self.assertEqual(list(request.POST.items()), list(data.items())) def test_standard_behaviour_determines_form_content_PUT(self): """ @@ -117,14 +127,14 @@ class TestContentParsing(TestCase): request = Request(factory.put('/', data)) request.parsers = (FormParser(), MultiPartParser()) - self.assertEqual(request.DATA.items(), data.items()) + self.assertEqual(list(request.DATA.items()), list(data.items())) def test_standard_behaviour_determines_non_form_content_PUT(self): """ Ensure request.DATA returns content for PUT request with non-form content. """ - content = 'qwerty' + content = six.b('qwerty') content_type = 'text/plain' request = Request(factory.put('/', content, content_type=content_type)) request.parsers = (PlainTextParser(), ) diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py index 875f4d42..aecf83f4 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/response.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from django.test import TestCase from rest_framework.compat import patterns, url, include from rest_framework.response import Response @@ -9,6 +10,7 @@ from rest_framework.renderers import ( BrowsableAPIRenderer ) from rest_framework.settings import api_settings +from rest_framework.compat import six class MockPickleRenderer(BaseRenderer): @@ -22,8 +24,8 @@ class MockJsonRenderer(BaseRenderer): DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' -RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x -RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x +RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii') +RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') class RendererA(BaseRenderer): @@ -83,39 +85,39 @@ class RendererIntegrationTests(TestCase): def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" resp = self.client.get('/') - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_head_method_serializes_no_content(self): """No response must be included in HEAD requests.""" resp = self.client.head('/') - self.assertEquals(resp.status_code, DUMMYSTATUS) - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, '') + self.assertEqual(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, six.b('')) def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_default_case(self): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEquals(resp['Content-Type'], RendererA.media_type) - self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_non_default_case(self): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_accept_query(self): """The '_accept' query string should behave in the same way as the Accept header.""" @@ -124,34 +126,34 @@ class RendererIntegrationTests(TestCase): RendererB.media_type ) resp = self.client.get('/' + param) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_format_query(self): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/?format=%s' % RendererB.format) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_serializes_content_on_format_kwargs(self): """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/something.formatb') - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): """If both a 'format' query and a matching Accept header specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type) - self.assertEquals(resp['Content-Type'], RendererB.media_type) - self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) - self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEqual(resp.status_code, DUMMYSTATUS) class Issue122Tests(TestCase): diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py index 8c86e1fb..cb8d8132 100644 --- a/rest_framework/tests/reverse.py +++ b/rest_framework/tests/reverse.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from django.test import TestCase from django.test.client import RequestFactory from rest_framework.compat import patterns, url @@ -16,7 +17,7 @@ urlpatterns = patterns('', class ReverseTests(TestCase): """ - Tests for fully qualifed URLs when using `reverse`. + Tests for fully qualified URLs when using `reverse`. """ urls = 'rest_framework.tests.reverse' diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index bd96ba23..beb372c2 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -1,10 +1,12 @@ -import datetime -import pickle +from __future__ import unicode_literals +from django.utils.datastructures import MultiValueDict from django.test import TestCase from rest_framework import serializers from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) +import datetime +import pickle class SubComment(object): @@ -54,6 +56,19 @@ class ActionItemSerializer(serializers.ModelSerializer): model = ActionItem +class ActionItemSerializerCustomRestore(serializers.ModelSerializer): + + class Meta: + model = ActionItem + + def restore_object(self, data, instance=None): + if instance is None: + return ActionItem(**data) + for key, val in data.items(): + setattr(instance, key, val) + return instance + + class PersonSerializer(serializers.ModelSerializer): info = serializers.Field(source='info') @@ -76,6 +91,11 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): fields = ['some_integer'] +class BrokenModelSerializer(serializers.ModelSerializer): + class Meta: + fields = ['some_field'] + + class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -92,7 +112,7 @@ class BasicTests(TestCase): self.expected = { 'email': 'tom@example.com', 'content': 'Happy new year!', - 'created': datetime.datetime(2012, 1, 1), + 'created': '2012-01-01T00:00:00', 'sub_comment': 'And Merry Christmas!' } self.person_data = {'name': 'dwight', 'age': 35} @@ -107,39 +127,39 @@ class BasicTests(TestCase): 'created': None, 'sub_comment': '' } - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_retrieve(self): serializer = CommentSerializer(self.comment) - self.assertEquals(serializer.data, self.expected) + self.assertEqual(serializer.data, self.expected) def test_create(self): serializer = CommentSerializer(data=self.data) expected = self.comment - self.assertEquals(serializer.is_valid(), True) - self.assertEquals(serializer.object, expected) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected) self.assertFalse(serializer.object is expected) - self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') + self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') def test_update(self): serializer = CommentSerializer(self.comment, data=self.data) expected = self.comment - self.assertEquals(serializer.is_valid(), True) - self.assertEquals(serializer.object, expected) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected) self.assertTrue(serializer.object is expected) - self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') + self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') def test_partial_update(self): msg = 'Merry New Year!' partial_data = {'content': msg} serializer = CommentSerializer(self.comment, data=partial_data) - self.assertEquals(serializer.is_valid(), False) + self.assertEqual(serializer.is_valid(), False) serializer = CommentSerializer(self.comment, data=partial_data, partial=True) expected = self.comment self.assertEqual(serializer.is_valid(), True) - self.assertEquals(serializer.object, expected) + self.assertEqual(serializer.object, expected) self.assertTrue(serializer.object is expected) - self.assertEquals(serializer.data['content'], msg) + self.assertEqual(serializer.data['content'], msg) def test_model_fields_as_expected(self): """ @@ -147,7 +167,7 @@ class BasicTests(TestCase): in the Meta data """ serializer = PersonSerializer(self.person) - self.assertEquals(set(serializer.data.keys()), + self.assertEqual(set(serializer.data.keys()), set(['name', 'age', 'info'])) def test_field_with_dictionary(self): @@ -156,19 +176,45 @@ class BasicTests(TestCase): """ serializer = PersonSerializer(self.person) expected = self.person_data - self.assertEquals(serializer.data['info'], expected) + self.assertEqual(serializer.data['info'], expected) def test_read_only_fields(self): """ Attempting to update fields set as read_only should have no effect. """ - serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(serializer.errors, {}) + self.assertEqual(serializer.errors, {}) # Assert age is unchanged (35) - self.assertEquals(instance.age, self.person_data['age']) + self.assertEqual(instance.age, self.person_data['age']) + + +class DictStyleSerializer(serializers.Serializer): + """ + Note that we don't have any `restore_object` method, so the default + case of simply returning a dict will apply. + """ + email = serializers.EmailField() + + +class DictStyleSerializerTests(TestCase): + def test_dict_style_deserialize(self): + """ + Ensure serializers can deserialize into a dict. + """ + data = {'email': 'foo@example.com'} + serializer = DictStyleSerializer(data=data) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, data) + + def test_dict_style_serialize(self): + """ + Ensure serializers can serialize dict objects. + """ + data = {'email': 'foo@example.com'} + serializer = DictStyleSerializer(data) + self.assertEqual(serializer.data, data) class ValidationTests(TestCase): @@ -183,18 +229,17 @@ class ValidationTests(TestCase): 'content': 'x' * 1001, 'created': datetime.datetime(2012, 1, 1) } - self.actionitem = ActionItem(title='Some to do item', - ) + self.actionitem = ActionItem(title='Some to do item',) def test_create(self): serializer = CommentSerializer(data=self.data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) def test_update(self): serializer = CommentSerializer(self.comment, data=self.data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) def test_update_missing_field(self): data = { @@ -202,8 +247,8 @@ class ValidationTests(TestCase): 'created': datetime.datetime(2012, 1, 1) } serializer = CommentSerializer(self.comment, data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'email': ['This field is required.']}) def test_missing_bool_with_default(self): """Make sure that a boolean value with a 'False' value is not @@ -213,52 +258,36 @@ class ValidationTests(TestCase): #No 'done' value. } serializer = ActionItemSerializer(self.actionitem, data=data) - self.assertEquals(serializer.is_valid(), True) - self.assertEquals(serializer.errors, {}) - - def test_field_validation(self): - - class CommentSerializerWithFieldValidator(CommentSerializer): - - def validate_content(self, attrs, source): - value = attrs[source] - if "test" not in value: - raise serializers.ValidationError("Test not in value") - return attrs - - data = { - 'email': 'tom@example.com', - 'content': 'A test comment', - 'created': datetime.datetime(2012, 1, 1) - } - - serializer = CommentSerializerWithFieldValidator(data=data) - self.assertTrue(serializer.is_valid()) - - data['content'] = 'This should not validate' - - serializer = CommentSerializerWithFieldValidator(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.errors, {}) def test_bad_type_data_is_false(self): """ Data of the wrong type is not valid. """ data = ['i am', 'a', 'list'] - serializer = CommentSerializer(self.comment, data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) + serializer = CommentSerializer(self.comment, data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertTrue(isinstance(serializer.errors, list)) + + self.assertEqual( + serializer.errors, + [ + {'non_field_errors': ['Invalid data']}, + {'non_field_errors': ['Invalid data']}, + {'non_field_errors': ['Invalid data']} + ] + ) data = 'and i am a string' serializer = CommentSerializer(self.comment, data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) data = 42 serializer = CommentSerializer(self.comment, data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) def test_cross_field_validation(self): @@ -282,23 +311,37 @@ class ValidationTests(TestCase): serializer = CommentSerializerWithCrossFieldValidator(data=data) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) + self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']}) def test_null_is_true_fields(self): """ Omitting a value for null-field should validate. """ serializer = PersonSerializer(data={'name': 'marko'}) - self.assertEquals(serializer.is_valid(), True) - self.assertEquals(serializer.errors, {}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.errors, {}) def test_modelserializer_max_length_exceeded(self): data = { 'title': 'x' * 201, } serializer = ActionItemSerializer(data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) + + def test_modelserializer_max_length_exceeded_with_custom_restore(self): + """ + When overriding ModelSerializer.restore_object, validation tests should still apply. + Regression test for #623. + + https://github.com/tomchristie/django-rest-framework/pull/623 + """ + data = { + 'title': 'x' * 201, + } + serializer = ActionItemSerializerCustomRestore(data=data) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) def test_default_modelfield_max_length_exceeded(self): data = { @@ -306,15 +349,99 @@ class ValidationTests(TestCase): 'info': 'x' * 13, } serializer = ActionItemSerializer(data=data) - self.assertEquals(serializer.is_valid(), False) - self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']}) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']}) + + def test_datetime_validation_failure(self): + """ + Test DateTimeField validation errors on non-str values. + Regression test for #669. + + https://github.com/tomchristie/django-rest-framework/issues/669 + """ + data = self.data + data['created'] = 0 + + serializer = CommentSerializer(data=data) + self.assertEqual(serializer.is_valid(), False) + + self.assertIn('created', serializer.errors) + + def test_missing_model_field_exception_msg(self): + """ + Assert that a meaningful exception message is outputted when the model + field is missing (e.g. when mistyping ``model``). + """ + try: + serializer = BrokenModelSerializer() + except AssertionError as e: + self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option") + except: + self.fail('Wrong exception type thrown.') + + +class CustomValidationTests(TestCase): + class CommentSerializerWithFieldValidator(CommentSerializer): + + def validate_email(self, attrs, source): + value = attrs[source] + + return attrs + + def validate_content(self, attrs, source): + value = attrs[source] + if "test" not in value: + raise serializers.ValidationError("Test not in value") + return attrs + + def test_field_validation(self): + data = { + 'email': 'tom@example.com', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + + serializer = self.CommentSerializerWithFieldValidator(data=data) + self.assertTrue(serializer.is_valid()) + + data['content'] = 'This should not validate' + + serializer = self.CommentSerializerWithFieldValidator(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'content': ['Test not in value']}) + + def test_missing_data(self): + """ + Make sure that validate_content isn't called if the field is missing + """ + incomplete_data = { + 'email': 'tom@example.com', + 'created': datetime.datetime(2012, 1, 1) + } + serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'content': ['This field is required.']}) + + def test_wrong_data(self): + """ + Make sure that validate_content isn't called if the field input is wrong + """ + wrong_data = { + 'email': 'not an email', + 'content': 'A test comment', + 'created': datetime.datetime(2012, 1, 1) + } + serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'email': ['Enter a valid e-mail address.']}) class PositiveIntegerAsChoiceTests(TestCase): def test_positive_integer_in_json_is_correctly_parsed(self): - data = {'some_integer':1} + data = {'some_integer': 1} serializer = PositiveIntegerAsChoiceSerializer(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) + class ModelValidationTests(TestCase): def test_validate_unique(self): @@ -326,7 +453,7 @@ class ModelValidationTests(TestCase): serializer.save() second_serializer = AlbumsSerializer(data={'title': 'a'}) self.assertFalse(second_serializer.is_valid()) - self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']}) + self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']}) def test_foreign_key_with_partial(self): """ @@ -364,15 +491,15 @@ class RegexValidationTest(TestCase): def test_create_failed(self): serializer = BookSerializer(data={'isbn': '1234567890'}) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) serializer = BookSerializer(data={'isbn': '12345678901234'}) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) self.assertFalse(serializer.is_valid()) - self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) + self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) def test_create_success(self): serializer = BookSerializer(data={'isbn': '1234567890123'}) @@ -417,7 +544,7 @@ class ManyToManyTests(TestCase): """ serializer = self.serializer_class(instance=self.instance) expected = self.data - self.assertEquals(serializer.data, expected) + self.assertEqual(serializer.data, expected) def test_create(self): """ @@ -425,11 +552,11 @@ class ManyToManyTests(TestCase): """ data = {'rel': [self.anchor.id]} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 2) - self.assertEquals(instance.pk, 2) - self.assertEquals(list(instance.rel.all()), [self.anchor]) + self.assertEqual(len(ManyToManyModel.objects.all()), 2) + self.assertEqual(instance.pk, 2) + self.assertEqual(list(instance.rel.all()), [self.anchor]) def test_update(self): """ @@ -439,11 +566,11 @@ class ManyToManyTests(TestCase): new_anchor.save() data = {'rel': [self.anchor.id, new_anchor.id]} serializer = self.serializer_class(self.instance, data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor]) + self.assertEqual(len(ManyToManyModel.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor]) def test_create_empty_relationship(self): """ @@ -452,11 +579,11 @@ class ManyToManyTests(TestCase): """ data = {'rel': []} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 2) - self.assertEquals(instance.pk, 2) - self.assertEquals(list(instance.rel.all()), []) + self.assertEqual(len(ManyToManyModel.objects.all()), 2) + self.assertEqual(instance.pk, 2) + self.assertEqual(list(instance.rel.all()), []) def test_update_empty_relationship(self): """ @@ -467,11 +594,11 @@ class ManyToManyTests(TestCase): new_anchor.save() data = {'rel': []} serializer = self.serializer_class(self.instance, data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(list(instance.rel.all()), []) + self.assertEqual(len(ManyToManyModel.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(list(instance.rel.all()), []) def test_create_empty_relationship_flat_data(self): """ @@ -479,19 +606,20 @@ class ManyToManyTests(TestCase): containing no items, using a representation that does not support lists (eg form data). """ - data = {'rel': ''} + data = MultiValueDict() + data.setlist('rel', ['']) serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ManyToManyModel.objects.all()), 2) - self.assertEquals(instance.pk, 2) - self.assertEquals(list(instance.rel.all()), []) + self.assertEqual(len(ManyToManyModel.objects.all()), 2) + self.assertEqual(instance.pk, 2) + self.assertEqual(list(instance.rel.all()), []) class ReadOnlyManyToManyTests(TestCase): def setUp(self): class ReadOnlyManyToManySerializer(serializers.ModelSerializer): - rel = serializers.ManyRelatedField(read_only=True) + rel = serializers.RelatedField(many=True, read_only=True) class Meta: model = ReadOnlyManyToManyModel @@ -519,12 +647,12 @@ class ReadOnlyManyToManyTests(TestCase): new_anchor.save() data = {'rel': [self.anchor.id, new_anchor.id]} serializer = self.serializer_class(self.instance, data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) - self.assertEquals(instance.pk, 1) + self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) + self.assertEqual(instance.pk, 1) # rel is still as original (1 entry) - self.assertEquals(list(instance.rel.all()), [self.anchor]) + self.assertEqual(list(instance.rel.all()), [self.anchor]) def test_update_without_relationship(self): """ @@ -535,12 +663,12 @@ class ReadOnlyManyToManyTests(TestCase): new_anchor.save() data = {} serializer = self.serializer_class(self.instance, data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1) - self.assertEquals(instance.pk, 1) + self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) + self.assertEqual(instance.pk, 1) # rel is still as original (1 entry) - self.assertEquals(list(instance.rel.all()), [self.anchor]) + self.assertEqual(list(instance.rel.all()), [self.anchor]) class DefaultValueTests(TestCase): @@ -555,35 +683,35 @@ class DefaultValueTests(TestCase): def test_create_using_default(self): data = {} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(self.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(instance.text, 'foobar') + self.assertEqual(len(self.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(instance.text, 'foobar') def test_create_overriding_default(self): data = {'text': 'overridden'} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(self.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(instance.text, 'overridden') + self.assertEqual(len(self.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(instance.text, 'overridden') def test_partial_update_default(self): """ Regression test for issue #532 """ data = {'text': 'overridden'} serializer = self.serializer_class(data=data, partial=True) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() data = {'extra': 'extra_value'} serializer = self.serializer_class(instance=instance, data=data, partial=True) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(instance.extra, 'extra_value') - self.assertEquals(instance.text, 'overridden') + self.assertEqual(instance.extra, 'extra_value') + self.assertEqual(instance.text, 'overridden') class CallableDefaultValueTests(TestCase): @@ -598,20 +726,20 @@ class CallableDefaultValueTests(TestCase): def test_create_using_default(self): data = {} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(self.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(instance.text, 'foobar') + self.assertEqual(len(self.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(instance.text, 'foobar') def test_create_overriding_default(self): data = {'text': 'overridden'} serializer = self.serializer_class(data=data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) instance = serializer.save() - self.assertEquals(len(self.objects.all()), 1) - self.assertEquals(instance.pk, 1) - self.assertEquals(instance.text, 'overridden') + self.assertEqual(len(self.objects.all()), 1) + self.assertEqual(instance.pk, 1) + self.assertEqual(instance.text, 'overridden') class ManyRelatedTests(TestCase): @@ -660,6 +788,9 @@ class ManyRelatedTests(TestCase): class RelatedTraversalTest(TestCase): def test_nested_traversal(self): + """ + Source argument should support dotted.source notation. + """ user = Person.objects.create(name="django") post = BlogPost.objects.create(title="Test blog post", writer=user) post.blogpostcomment_set.create(text="I love this blog post") @@ -686,11 +817,11 @@ class RelatedTraversalTest(TestCase): serializer = BlogPostSerializer(instance=post) expected = { - 'title': u'Test blog post', + 'title': 'Test blog post', 'comments': [{ - 'text': u'I love this blog post', + 'text': 'I love this blog post', 'post_owner': { - "name": u"django", + "name": "django", "age": None } }] @@ -698,6 +829,41 @@ class RelatedTraversalTest(TestCase): self.assertEqual(serializer.data, expected) + def test_nested_traversal_with_none(self): + """ + If a component of the dotted.source is None, return None for the field. + """ + from rest_framework.tests.models import NullableForeignKeySource + instance = NullableForeignKeySource.objects.create(name='Source with null FK') + + class NullableSourceSerializer(serializers.Serializer): + target_name = serializers.Field(source='target.name') + + serializer = NullableSourceSerializer(instance=instance) + + expected = { + 'target_name': None, + } + + self.assertEqual(serializer.data, expected) + + def test_queryset_nested_traversal(self): + """ + Relational fields should be able to use methods as their source. + """ + BlogPost.objects.create(title='blah') + + class QuerysetMethodSerializer(serializers.Serializer): + blogposts = serializers.RelatedField(many=True, source='get_all_blogposts') + + class ClassWithQuerysetMethod(object): + def get_all_blogposts(self): + return BlogPost.objects + + obj = ClassWithQuerysetMethod() + serializer = QuerysetMethodSerializer(obj) + self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']}) + class SerializerMethodFieldTests(TestCase): def setUp(self): @@ -725,8 +891,8 @@ class SerializerMethodFieldTests(TestCase): serializer = self.serializer_class(source_data) expected = { - 'beep': u'hello!', - 'boop': [u'a', u'b', u'c'], + 'beep': 'hello!', + 'boop': ['a', 'b', 'c'], 'boop_count': 3, } @@ -742,7 +908,7 @@ class BlankFieldTests(TestCase): model = BlankFieldModel class BlankFieldSerializer(serializers.Serializer): - title = serializers.CharField(blank=True) + title = serializers.CharField(required=False) class NotBlankFieldModelSerializer(serializers.ModelSerializer): class Meta: @@ -759,15 +925,15 @@ class BlankFieldTests(TestCase): def test_create_blank_field(self): serializer = self.serializer_class(data=self.data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) def test_create_model_blank_field(self): serializer = self.model_serializer_class(data=self.data) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) def test_create_model_null_field(self): serializer = self.model_serializer_class(data={'title': None}) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) def test_create_not_blank_field(self): """ @@ -775,7 +941,7 @@ class BlankFieldTests(TestCase): is considered invalid in a non-model serializer """ serializer = self.not_blank_serializer_class(data=self.data) - self.assertEquals(serializer.is_valid(), False) + self.assertEqual(serializer.is_valid(), False) def test_create_model_not_blank_field(self): """ @@ -783,11 +949,11 @@ class BlankFieldTests(TestCase): is considered invalid in a model serializer """ serializer = self.not_blank_model_serializer_class(data=self.data) - self.assertEquals(serializer.is_valid(), False) + self.assertEqual(serializer.is_valid(), False) - def test_create_model_null_field(self): + def test_create_model_empty_field(self): serializer = self.model_serializer_class(data={}) - self.assertEquals(serializer.is_valid(), True) + self.assertEqual(serializer.is_valid(), True) #test for issue #460 @@ -811,7 +977,21 @@ class SerializerPickleTests(TestCase): class Meta: model = Person fields = ('name', 'age') - pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data) + pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0) + + def test_getstate_method_should_not_return_none(self): + """ + Regression test for #645. + """ + data = serializers.DictWithMetadata({1: 1}) + self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1})) + + def test_serializer_data_is_pickleable(self): + """ + Another regression test for #645. + """ + data = serializers.SortedDictWithMetadata({1: 1}) + repr(pickle.loads(pickle.dumps(data, 0))) class DepthTest(TestCase): @@ -825,8 +1005,8 @@ class DepthTest(TestCase): depth = 1 serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': u'Test blog post', - 'writer': {'id': 1, 'name': u'django', 'age': 1}} + expected = {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}} self.assertEqual(serializer.data, expected) @@ -845,8 +1025,8 @@ class DepthTest(TestCase): model = BlogPost serializer = BlogPostSerializer(instance=post) - expected = {'id': 1, 'title': u'Test blog post', - 'writer': {'id': 1, 'name': u'django', 'age': 1}} + expected = {'id': 1, 'title': 'Test blog post', + 'writer': {'id': 1, 'name': 'django', 'age': 1}} self.assertEqual(serializer.data, expected) @@ -901,3 +1081,32 @@ class NestedSerializerContextTests(TestCase): # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data + + +class DeserializeListTestCase(TestCase): + + def setUp(self): + self.data = { + 'email': 'nobody@nowhere.com', + 'content': 'This is some test content', + 'created': datetime.datetime(2013, 3, 7), + } + + def test_no_errors(self): + data = [self.data.copy() for x in range(0, 3)] + serializer = CommentSerializer(data=data) + self.assertTrue(serializer.is_valid()) + self.assertTrue(isinstance(serializer.object, list)) + self.assertTrue( + all((isinstance(item, Comment) for item in serializer.object)) + ) + + def test_errors_return_as_list(self): + invalid_item = self.data.copy() + invalid_item['email'] = '' + data = [self.data.copy(), invalid_item, self.data.copy()] + + serializer = CommentSerializer(data=data) + self.assertFalse(serializer.is_valid()) + expected = [{}, {'email': ['This field is required.']}, {}] + self.assertEqual(serializer.errors, expected) diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py index 0293fdc3..857375c2 100644 --- a/rest_framework/tests/settings.py +++ b/rest_framework/tests/settings.py @@ -1,4 +1,5 @@ """Tests for the settings module""" +from __future__ import unicode_literals from django.test import TestCase from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py index 30df5cef..e1644a6b 100644 --- a/rest_framework/tests/status.py +++ b/rest_framework/tests/status.py @@ -1,4 +1,5 @@ """Tests for the status module""" +from __future__ import unicode_literals from django.test import TestCase from rest_framework import status @@ -8,5 +9,5 @@ class TestStatus(TestCase): def test_status(self): """Ensure the status module is present and correct.""" - self.assertEquals(200, status.HTTP_200_OK) - self.assertEquals(404, status.HTTP_404_NOT_FOUND) + self.assertEqual(200, status.HTTP_200_OK) + self.assertEqual(404, status.HTTP_404_NOT_FOUND) diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py index 97f492ff..f8c2579e 100644 --- a/rest_framework/tests/testcases.py +++ b/rest_framework/tests/testcases.py @@ -1,4 +1,5 @@ # http://djangosnippets.org/snippets/1011/ +from __future__ import unicode_literals from django.conf import settings from django.core.management import call_command from django.db.models import loading diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py index adeaf6da..08f88e11 100644 --- a/rest_framework/tests/tests.py +++ b/rest_framework/tests/tests.py @@ -2,6 +2,7 @@ Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers. """ +from __future__ import unicode_literals import os modules = [filename.rsplit('.', 1)[0] diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py index 4b98b941..11cbd8eb 100644 --- a/rest_framework/tests/throttling.py +++ b/rest_framework/tests/throttling.py @@ -1,11 +1,10 @@ """ Tests for the throttling implementations in the permissions module. """ - +from __future__ import unicode_literals from django.test import TestCase from django.contrib.auth.models import User from django.core.cache import cache - from django.test.client import RequestFactory from rest_framework.views import APIView from rest_framework.throttling import UserRateThrottle @@ -104,7 +103,7 @@ class ThrottlingTests(TestCase): self.set_throttle_timer(view, timer) response = view.as_view()(request) if expect is not None: - self.assertEquals(response['X-Throttle-Wait-Seconds'], expect) + self.assertEqual(response['X-Throttle-Wait-Seconds'], expect) else: self.assertFalse('X-Throttle-Wait-Seconds' in response) diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py new file mode 100644 index 00000000..29ed4a96 --- /dev/null +++ b/rest_framework/tests/urlpatterns.py @@ -0,0 +1,76 @@ +from __future__ import unicode_literals +from collections import namedtuple +from django.core import urlresolvers +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework.compat import patterns, url, include +from rest_framework.urlpatterns import format_suffix_patterns + + +# A container class for test paths for the test case +URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) + + +def dummy_view(request, *args, **kwargs): + pass + + +class FormatSuffixTests(TestCase): + """ + Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. + """ + def _resolve_urlpatterns(self, urlpatterns, test_paths): + factory = RequestFactory() + try: + urlpatterns = format_suffix_patterns(urlpatterns) + except Exception: + self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") + resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + for test_path in test_paths: + request = factory.get(test_path.path) + try: + callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) + except Exception: + self.fail("Failed to resolve URL: %s" % request.path_info) + self.assertEqual(callback_args, test_path.args) + self.assertEqual(callback_kwargs, test_path.kwargs) + + def test_format_suffix(self): + urlpatterns = patterns( + '', + url(r'^test$', dummy_view), + ) + test_paths = [ + URLTestPath('/test', (), {}), + URLTestPath('/test.api', (), {'format': 'api'}), + URLTestPath('/test.asdf', (), {'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) + + def test_default_args(self): + urlpatterns = patterns( + '', + url(r'^test$', dummy_view, {'foo': 'bar'}), + ) + test_paths = [ + URLTestPath('/test', (), {'foo': 'bar', }), + URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), + URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) + + def test_included_urls(self): + nested_patterns = patterns( + '', + url(r'^path$', dummy_view) + ) + urlpatterns = patterns( + '', + url(r'^test/', include(nested_patterns), {'foo': 'bar'}), + ) + test_paths = [ + URLTestPath('/test/path', (), {'foo': 'bar', }), + URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), + URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), + ] + self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py index 3906adb9..8c87917d 100644 --- a/rest_framework/tests/utils.py +++ b/rest_framework/tests/utils.py @@ -1,9 +1,10 @@ -from django.test.client import RequestFactory, FakePayload +from __future__ import unicode_literals +from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory from django.test.client import MULTIPART_CONTENT -from urlparse import urlparse +from rest_framework.compat import urlparse -class RequestFactory(RequestFactory): +class RequestFactory(_RequestFactory): def __init__(self, **defaults): super(RequestFactory, self).__init__(**defaults) @@ -14,7 +15,7 @@ class RequestFactory(RequestFactory): patch_data = self._encode_data(data, content_type) - parsed = urlparse(path) + parsed = urlparse.urlparse(path) r = { 'CONTENT_LENGTH': len(patch_data), 'CONTENT_TYPE': content_type, @@ -25,3 +26,15 @@ class RequestFactory(RequestFactory): } r.update(extra) return self.request(**r) + + +class Client(_Client, RequestFactory): + def patch(self, path, data={}, content_type=MULTIPART_CONTENT, + follow=False, **extra): + """ + Send a resource to the server using PATCH. + """ + response = super(Client, self).patch(path, data=data, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/validation.py new file mode 100644 index 00000000..cbdd6515 --- /dev/null +++ b/rest_framework/tests/validation.py @@ -0,0 +1,65 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import generics, serializers, status +from rest_framework.tests.utils import RequestFactory +import json + +factory = RequestFactory() + + +# Regression for #666 + +class ValidationModel(models.Model): + blank_validated_field = models.CharField(max_length=255) + + +class ValidationModelSerializer(serializers.ModelSerializer): + class Meta: + model = ValidationModel + fields = ('blank_validated_field',) + read_only_fields = ('blank_validated_field',) + + +class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView): + model = ValidationModel + serializer_class = ValidationModelSerializer + + +class TestPreSaveValidationExclusions(TestCase): + def test_pre_save_validation_exclusions(self): + """ + Somewhat weird test case to ensure that we don't perform model + validation on read only fields. + """ + obj = ValidationModel.objects.create(blank_validated_field='') + request = factory.put('/', json.dumps({}), + content_type='application/json') + view = UpdateValidationModel().as_view() + response = view(request, pk=obj.pk).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +# Regression for #653 + +class ShouldValidateModel(models.Model): + should_validate_field = models.CharField(max_length=255) + + +class ShouldValidateModelSerializer(serializers.ModelSerializer): + renamed = serializers.CharField(source='should_validate_field', required=False) + + class Meta: + model = ShouldValidateModel + fields = ('renamed',) + + +class TestPreSaveValidationExclusions(TestCase): + def test_renamed_fields_are_model_validated(self): + """ + Ensure fields with 'source' applied do get still get model validation. + """ + # We've set `required=False` on the serializer, but the model + # does not have `blank=True`, so this serializer should not validate. + serializer = ShouldValidateModelSerializer(data={'renamed': ''}) + self.assertEqual(serializer.is_valid(), False) diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py deleted file mode 100644 index c032985e..00000000 --- a/rest_framework/tests/validators.py +++ /dev/null @@ -1,329 +0,0 @@ -# from django import forms -# from django.db import models -# from django.test import TestCase -# from rest_framework.response import ImmediateResponse -# from rest_framework.views import View - - -# class TestDisabledValidations(TestCase): -# """Tests on FormValidator with validation disabled by setting form to None""" - -# def test_disabled_form_validator_returns_content_unchanged(self): -# """If the view's form attribute is None then FormValidator(view).validate_request(content, None) -# should just return the content unmodified.""" -# class DisabledFormResource(FormResource): -# form = None - -# class MockView(View): -# resource = DisabledFormResource - -# view = MockView() -# content = {'qwerty': 'uiop'} -# self.assertEqual(FormResource(view).validate_request(content, None), content) - -# def test_disabled_form_validator_get_bound_form_returns_none(self): -# """If the view's form attribute is None on then -# FormValidator(view).get_bound_form(content) should just return None.""" -# class DisabledFormResource(FormResource): -# form = None - -# class MockView(View): -# resource = DisabledFormResource - -# view = MockView() -# content = {'qwerty': 'uiop'} -# self.assertEqual(FormResource(view).get_bound_form(content), None) - -# def test_disabled_model_form_validator_returns_content_unchanged(self): -# """If the view's form is None and does not have a Resource with a model set then -# ModelFormValidator(view).validate_request(content, None) should just return the content unmodified.""" - -# class DisabledModelFormView(View): -# resource = ModelResource - -# view = DisabledModelFormView() -# content = {'qwerty': 'uiop'} -# self.assertEqual(ModelResource(view).get_bound_form(content), None) - -# def test_disabled_model_form_validator_get_bound_form_returns_none(self): -# """If the form attribute is None on FormValidatorMixin then get_bound_form(content) should just return None.""" -# class DisabledModelFormView(View): -# resource = ModelResource - -# view = DisabledModelFormView() -# content = {'qwerty': 'uiop'} -# self.assertEqual(ModelResource(view).get_bound_form(content), None) - - -# class TestNonFieldErrors(TestCase): -# """Tests against form validation errors caused by non-field errors. (eg as might be caused by some custom form validation)""" - -# def test_validate_failed_due_to_non_field_error_returns_appropriate_message(self): -# """If validation fails with a non-field error, ensure the response a non-field error""" -# class MockForm(forms.Form): -# field1 = forms.CharField(required=False) -# field2 = forms.CharField(required=False) -# ERROR_TEXT = 'You may not supply both field1 and field2' - -# def clean(self): -# if 'field1' in self.cleaned_data and 'field2' in self.cleaned_data: -# raise forms.ValidationError(self.ERROR_TEXT) -# return self.cleaned_data - -# class MockResource(FormResource): -# form = MockForm - -# class MockView(View): -# pass - -# view = MockView() -# content = {'field1': 'example1', 'field2': 'example2'} -# try: -# MockResource(view).validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'errors': [MockForm.ERROR_TEXT]}) -# else: -# self.fail('ImmediateResponse was not raised') - - -# class TestFormValidation(TestCase): -# """Tests which check basic form validation. -# Also includes the same set of tests with a ModelFormValidator for which the form has been explicitly set. -# (ModelFormValidator should behave as FormValidator if a form is set rather than relying on the default ModelForm)""" -# def setUp(self): -# class MockForm(forms.Form): -# qwerty = forms.CharField(required=True) - -# class MockFormResource(FormResource): -# form = MockForm - -# class MockModelResource(ModelResource): -# form = MockForm - -# class MockFormView(View): -# resource = MockFormResource - -# class MockModelFormView(View): -# resource = MockModelResource - -# self.MockFormResource = MockFormResource -# self.MockModelResource = MockModelResource -# self.MockFormView = MockFormView -# self.MockModelFormView = MockModelFormView - -# def validation_returns_content_unchanged_if_already_valid_and_clean(self, validator): -# """If the content is already valid and clean then validate(content) should just return the content unmodified.""" -# content = {'qwerty': 'uiop'} -# self.assertEqual(validator.validate_request(content, None), content) - -# def validation_failure_raises_response_exception(self, validator): -# """If form validation fails a ResourceException 400 (Bad Request) should be raised.""" -# content = {} -# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) - -# def validation_does_not_allow_extra_fields_by_default(self, validator): -# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. -# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up -# broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'qwerty': 'uiop', 'extra': 'extra'} -# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) - -# def validation_allows_extra_fields_if_explicitly_set(self, validator): -# """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names.""" -# content = {'qwerty': 'uiop', 'extra': 'extra'} -# validator._validate(content, None, allowed_extra_fields=('extra',)) - -# def validation_allows_unknown_fields_if_explicitly_allowed(self, validator): -# """If we set ``unknown_form_fields`` on the form resource, then don't -# raise errors on unexpected request data""" -# content = {'qwerty': 'uiop', 'extra': 'extra'} -# validator.allow_unknown_form_fields = True -# self.assertEqual({'qwerty': u'uiop'}, -# validator.validate_request(content, None), -# "Resource didn't accept unknown fields.") -# validator.allow_unknown_form_fields = False - -# def validation_does_not_require_extra_fields_if_explicitly_set(self, validator): -# """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names.""" -# content = {'qwerty': 'uiop'} -# self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content) - -# def validation_failed_due_to_no_content_returns_appropriate_message(self, validator): -# """If validation fails due to no content, ensure the response contains a single non-field error""" -# content = {} -# try: -# validator.validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}}) -# else: -# self.fail('ResourceException was not raised') - -# def validation_failed_due_to_field_error_returns_appropriate_message(self, validator): -# """If validation fails due to a field error, ensure the response contains a single field error""" -# content = {'qwerty': ''} -# try: -# validator.validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}}) -# else: -# self.fail('ResourceException was not raised') - -# def validation_failed_due_to_invalid_field_returns_appropriate_message(self, validator): -# """If validation fails due to an invalid field, ensure the response contains a single field error""" -# content = {'qwerty': 'uiop', 'extra': 'extra'} -# try: -# validator.validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'field_errors': {'extra': ['This field does not exist.']}}) -# else: -# self.fail('ResourceException was not raised') - -# def validation_failed_due_to_multiple_errors_returns_appropriate_message(self, validator): -# """If validation for multiple reasons, ensure the response contains each error""" -# content = {'qwerty': '', 'extra': 'extra'} -# try: -# validator.validate_request(content, None) -# except ImmediateResponse, exc: -# response = exc.response -# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.'], -# 'extra': ['This field does not exist.']}}) -# else: -# self.fail('ResourceException was not raised') - -# # Tests on FormResource - -# def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) - -# def test_form_validation_failure_raises_response_exception(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failure_raises_response_exception(validator) - -# def test_validation_does_not_allow_extra_fields_by_default(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_does_not_allow_extra_fields_by_default(validator) - -# def test_validation_allows_extra_fields_if_explicitly_set(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_allows_extra_fields_if_explicitly_set(validator) - -# def test_validation_allows_unknown_fields_if_explicitly_allowed(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_allows_unknown_fields_if_explicitly_allowed(validator) - -# def test_validation_does_not_require_extra_fields_if_explicitly_set(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_does_not_require_extra_fields_if_explicitly_set(validator) - -# def test_validation_failed_due_to_no_content_returns_appropriate_message(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failed_due_to_no_content_returns_appropriate_message(validator) - -# def test_validation_failed_due_to_field_error_returns_appropriate_message(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failed_due_to_field_error_returns_appropriate_message(validator) - -# def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) - -# def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): -# validator = self.MockFormResource(self.MockFormView()) -# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) - -# # Same tests on ModelResource - -# def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) - -# def test_modelform_validation_failure_raises_response_exception(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failure_raises_response_exception(validator) - -# def test_modelform_validation_does_not_allow_extra_fields_by_default(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_does_not_allow_extra_fields_by_default(validator) - -# def test_modelform_validation_allows_extra_fields_if_explicitly_set(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_allows_extra_fields_if_explicitly_set(validator) - -# def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_does_not_require_extra_fields_if_explicitly_set(validator) - -# def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failed_due_to_no_content_returns_appropriate_message(validator) - -# def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failed_due_to_field_error_returns_appropriate_message(validator) - -# def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) - -# def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): -# validator = self.MockModelResource(self.MockModelFormView()) -# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) - - -# class TestModelFormValidator(TestCase): -# """Tests specific to ModelFormValidatorMixin""" - -# def setUp(self): -# """Create a validator for a model with two fields and a property.""" -# class MockModel(models.Model): -# qwerty = models.CharField(max_length=256) -# uiop = models.CharField(max_length=256, blank=True) - -# @property -# def read_only(self): -# return 'read only' - -# class MockResource(ModelResource): -# model = MockModel - -# class MockView(View): -# resource = MockResource - -# self.validator = MockResource(MockView) - -# def test_property_fields_are_allowed_on_model_forms(self): -# """Validation on ModelForms may include property fields that exist on the Model to be included in the input.""" -# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'} -# self.assertEqual(self.validator.validate_request(content, None), content) - -# def test_property_fields_are_not_required_on_model_forms(self): -# """Validation on ModelForms does not require property fields that exist on the Model to be included in the input.""" -# content = {'qwerty': 'example', 'uiop': 'example'} -# self.assertEqual(self.validator.validate_request(content, None), content) - -# def test_extra_fields_not_allowed_on_model_forms(self): -# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. -# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up -# broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'} -# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) - -# def test_validate_requires_fields_on_model_forms(self): -# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. -# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up -# broken clients more easily (eg submitting content with a misnamed field)""" -# content = {'read_only': 'read only'} -# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) - -# def test_validate_does_not_require_blankable_fields_on_model_forms(self): -# """Test standard ModelForm validation behaviour - fields with blank=True are not required.""" -# content = {'qwerty': 'example', 'read_only': 'read only'} -# self.validator.validate_request(content, None) - -# def test_model_form_validator_uses_model_forms(self): -# self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm)) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py index 7cd82656..994cf6dc 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/views.py @@ -1,4 +1,4 @@ -import copy +from __future__ import unicode_literals from django.test import TestCase from django.test.client import RequestFactory from rest_framework import status @@ -6,6 +6,7 @@ from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView +import copy factory = RequestFactory() @@ -49,10 +50,10 @@ class ClassBasedViewIntegrationTests(TestCase): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) expected = { - 'detail': u'JSON parse error - No JSON object could be decoded' + 'detail': 'JSON parse error - No JSON object could be decoded' } - self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEquals(sanitise_json_error(response.data), expected) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) def test_400_parse_error_tunneled_content(self): content = 'f00bar' @@ -64,10 +65,10 @@ class ClassBasedViewIntegrationTests(TestCase): request = factory.post('/', form_data) response = self.view(request) expected = { - 'detail': u'JSON parse error - No JSON object could be decoded' + 'detail': 'JSON parse error - No JSON object could be decoded' } - self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEquals(sanitise_json_error(response.data), expected) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) class FunctionBasedViewIntegrationTests(TestCase): @@ -78,10 +79,10 @@ class FunctionBasedViewIntegrationTests(TestCase): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) expected = { - 'detail': u'JSON parse error - No JSON object could be decoded' + 'detail': 'JSON parse error - No JSON object could be decoded' } - self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEquals(sanitise_json_error(response.data), expected) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) def test_400_parse_error_tunneled_content(self): content = 'f00bar' @@ -93,7 +94,7 @@ class FunctionBasedViewIntegrationTests(TestCase): request = factory.post('/', form_data) response = self.view(request) expected = { - 'detail': u'JSON parse error - No JSON object could be decoded' + 'detail': 'JSON parse error - No JSON object could be decoded' } - self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEquals(sanitise_json_error(response.data), expected) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(sanitise_json_error(response.data), expected) |
