diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/conftest.py | 20 | ||||
| -rw-r--r-- | tests/test_authentication.py | 430 | ||||
| -rw-r--r-- | tests/test_fields.py | 38 | ||||
| -rw-r--r-- | tests/test_generics.py | 6 | ||||
| -rw-r--r-- | tests/test_metadata.py | 30 | ||||
| -rw-r--r-- | tests/test_model_serializer.py | 8 | ||||
| -rw-r--r-- | tests/test_pagination.py | 926 | ||||
| -rw-r--r-- | tests/test_parsers.py | 60 | ||||
| -rw-r--r-- | tests/test_relations.py | 2 | ||||
| -rw-r--r-- | tests/test_renderers.py | 238 | ||||
| -rw-r--r-- | tests/test_serializer_bulk_update.py | 4 | ||||
| -rw-r--r-- | tests/test_templatetags.py | 13 | ||||
| -rw-r--r-- | tests/test_versioning.py | 223 |
13 files changed, 720 insertions, 1278 deletions
diff --git a/tests/conftest.py b/tests/conftest.py index 31142eaf..44ed070b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,26 +44,6 @@ def pytest_configure(): ), ) - try: - import oauth_provider # NOQA - import oauth2 # NOQA - except ImportError: - pass - else: - settings.INSTALLED_APPS += ( - 'oauth_provider', - ) - - try: - import provider # NOQA - except ImportError: - pass - else: - settings.INSTALLED_APPS += ( - 'provider', - 'provider.oauth2', - ) - # guardian is optional try: import guardian # NOQA diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 44837c4e..04c5782e 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -3,8 +3,7 @@ from django.conf.urls import patterns, url, include from django.contrib.auth.models import User from django.http import HttpResponse from django.test import TestCase -from django.utils import six, unittest -from django.utils.http import urlencode +from django.utils import six from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions from rest_framework import permissions @@ -16,17 +15,11 @@ from rest_framework.authentication import ( TokenAuthentication, BasicAuthentication, SessionAuthentication, - OAuthAuthentication, - OAuth2Authentication ) from rest_framework.authtoken.models import Token -from rest_framework.compat import oauth2_provider, oauth2_provider_scope -from rest_framework.compat import oauth, oauth_provider from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView import base64 -import time -import datetime factory = APIRequestFactory() @@ -50,37 +43,10 @@ urlpatterns = patterns( (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), - (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), - ( - r'^oauth-with-scope/$', - MockView.as_view( - authentication_classes=[OAuthAuthentication], - permission_classes=[permissions.TokenHasReadWriteScope] - ) - ), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')) ) -class OAuth2AuthenticationDebug(OAuth2Authentication): - allow_query_params_token = True - -if oauth2_provider is not None: - urlpatterns += patterns( - '', - url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), - url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), - url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), - url( - r'^oauth2-with-scope-test/$', - MockView.as_view( - authentication_classes=[OAuth2Authentication], - permission_classes=[permissions.TokenHasReadWriteScope] - ) - ) - ) - - class BasicAuthTests(TestCase): """Basic authentication""" urls = 'tests.test_authentication' @@ -276,400 +242,6 @@ class IncorrectCredentialsTests(TestCase): self.assertEqual(response.data, {'detail': 'Bad credentials'}) -class OAuthTests(TestCase): - """OAuth 1.0a authentication""" - urls = 'tests.test_authentication' - - def setUp(self): - # these imports are here because oauth is optional and hiding them in try..except block or compat - # could obscure problems if something breaks - from oauth_provider.models import Consumer, Scope - from oauth_provider.models import Token as OAuthToken - from oauth_provider import consts - - self.consts = consts - - self.csrf_client = APIClient(enforce_csrf_checks=True) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user(self.username, self.email, self.password) - - self.CONSUMER_KEY = 'consumer_key' - self.CONSUMER_SECRET = 'consumer_secret' - self.TOKEN_KEY = "token_key" - self.TOKEN_SECRET = "token_secret" - - self.consumer = Consumer.objects.create( - key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, - name='example', user=self.user, status=self.consts.ACCEPTED - ) - - self.scope = Scope.objects.create(name="resource name", url="api/") - self.token = OAuthToken.objects.create( - user=self.user, consumer=self.consumer, scope=self.scope, - token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, - is_approved=True - ) - - def _create_authorization_header(self): - params = { - 'oauth_version': "1.0", - 'oauth_nonce': oauth.generate_nonce(), - 'oauth_timestamp': int(time.time()), - 'oauth_token': self.token.key, - 'oauth_consumer_key': self.consumer.key - } - - req = oauth.Request(method="GET", url="http://example.com", parameters=params) - - signature_method = oauth.SignatureMethod_PLAINTEXT() - req.sign_request(signature_method, self.consumer, self.token) - - return req.to_header()["Authorization"] - - def _create_authorization_url_parameters(self): - params = { - 'oauth_version': "1.0", - 'oauth_nonce': oauth.generate_nonce(), - 'oauth_timestamp': int(time.time()), - 'oauth_token': self.token.key, - 'oauth_consumer_key': self.consumer.key - } - - req = oauth.Request(method="GET", url="http://example.com", parameters=params) - - signature_method = oauth.SignatureMethod_PLAINTEXT() - req.sign_request(signature_method, self.consumer, self.token) - return dict(req) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_form_passing_oauth(self): - """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_form_repeated_nonce_failing_oauth(self): - """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails""" - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - # simulate reply attack auth header containes already used (nonce, timestamp) pair - response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_form_token_removed_failing_oauth(self): - """Ensure POSTing when there is no OAuth access token in db fails""" - self.token.delete() - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_form_consumer_status_not_accepted_failing_oauth(self): - """Ensure POSTing when consumer status is anything other than ACCEPTED fails""" - for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED): - self.consumer.status = consumer_status - self.consumer.save() - - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_form_with_request_token_failing_oauth(self): - """Ensure POSTing with unauthorized request token instead of access token fails""" - self.token.token_type = self.token.REQUEST - self.token.save() - - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_form_with_urlencoded_parameters(self): - """Ensure POSTing with x-www-form-urlencoded auth parameters passes""" - params = self._create_authorization_url_parameters() - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_get_form_with_url_parameters(self): - """Ensure GETing with auth in url parameters passes""" - params = self._create_authorization_url_parameters() - response = self.csrf_client.get('/oauth/', params) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_hmac_sha1_signature_passes(self): - """Ensure POSTing using HMAC_SHA1 signature method passes""" - params = { - 'oauth_version': "1.0", - 'oauth_nonce': oauth.generate_nonce(), - 'oauth_timestamp': int(time.time()), - 'oauth_token': self.token.key, - 'oauth_consumer_key': self.consumer.key - } - - req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - - signature_method = oauth.SignatureMethod_HMAC_SHA1() - req.sign_request(signature_method, self.consumer, self.token) - auth = req.to_header()["Authorization"] - - response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_get_form_with_readonly_resource_passing_auth(self): - """Ensure POSTing with a readonly scope instead of a write scope fails""" - read_only_access_token = self.token - read_only_access_token.scope.is_readonly = True - read_only_access_token.scope.save() - params = self._create_authorization_url_parameters() - response = self.csrf_client.get('/oauth-with-scope/', params) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_form_with_readonly_resource_failing_auth(self): - """Ensure POSTing with a readonly resource instead of a write scope fails""" - read_only_access_token = self.token - read_only_access_token.scope.is_readonly = True - read_only_access_token.scope.save() - params = self._create_authorization_url_parameters() - response = self.csrf_client.post('/oauth-with-scope/', params) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_post_form_with_write_resource_passing_auth(self): - """Ensure POSTing with a write resource succeed""" - read_write_access_token = self.token - read_write_access_token.scope.is_readonly = False - read_write_access_token.scope.save() - params = self._create_authorization_url_parameters() - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_bad_consumer_key(self): - """Ensure POSTing using HMAC_SHA1 signature method passes""" - params = { - 'oauth_version': "1.0", - 'oauth_nonce': oauth.generate_nonce(), - 'oauth_timestamp': int(time.time()), - 'oauth_token': self.token.key, - 'oauth_consumer_key': 'badconsumerkey' - } - - req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - - signature_method = oauth.SignatureMethod_HMAC_SHA1() - req.sign_request(signature_method, self.consumer, self.token) - auth = req.to_header()["Authorization"] - - response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') - @unittest.skipUnless(oauth, 'oauth2 not installed') - def test_bad_token_key(self): - """Ensure POSTing using HMAC_SHA1 signature method passes""" - params = { - 'oauth_version': "1.0", - 'oauth_nonce': oauth.generate_nonce(), - 'oauth_timestamp': int(time.time()), - 'oauth_token': 'badtokenkey', - 'oauth_consumer_key': self.consumer.key - } - - req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - - signature_method = oauth.SignatureMethod_HMAC_SHA1() - req.sign_request(signature_method, self.consumer, self.token) - auth = req.to_header()["Authorization"] - - response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - -class OAuth2Tests(TestCase): - """OAuth 2.0 authentication""" - urls = 'tests.test_authentication' - - def setUp(self): - self.csrf_client = APIClient(enforce_csrf_checks=True) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user(self.username, self.email, self.password) - - self.CLIENT_ID = 'client_key' - self.CLIENT_SECRET = 'client_secret' - self.ACCESS_TOKEN = "access_token" - self.REFRESH_TOKEN = "refresh_token" - - self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( - client_id=self.CLIENT_ID, - client_secret=self.CLIENT_SECRET, - redirect_uri='', - client_type=0, - name='example', - user=None, - ) - - self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( - token=self.ACCESS_TOKEN, - client=self.oauth2_client, - user=self.user, - ) - self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( - user=self.user, - access_token=self.access_token, - client=self.oauth2_client - ) - - def _create_authorization_header(self, token=None): - return "Bearer {0}".format(token or self.access_token.token) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_authorization_header_token_type_failing(self): - """Ensure that a wrong token type lead to the correct HTTP error status code""" - auth = "Wrong token-type-obsviously" - response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_authorization_header_token_format_failing(self): - """Ensure that a wrong token format lead to the correct HTTP error status code""" - auth = "Bearer wrong token format" - response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_authorization_header_token_failing(self): - """Ensure that a wrong token lead to the correct HTTP error status code""" - auth = "Bearer wrong-token" - response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_authorization_header_token_missing(self): - """Ensure that a missing token lead to the correct HTTP error status code""" - auth = "Bearer" - response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_passing_auth(self): - """Ensure GETing form over OAuth with correct client credentials succeed""" - auth = self._create_authorization_header() - response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_passing_auth_url_transport(self): - """Ensure GETing form over OAuth with correct client credentials in form data succeed""" - response = self.csrf_client.post( - '/oauth2-test/', - data={'access_token': self.access_token.token} - ) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_passing_auth_url_transport(self): - """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" - query = urlencode({'access_token': self.access_token.token}) - response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_failing_auth_url_transport(self): - """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" - query = urlencode({'access_token': self.access_token.token}) - response = self.csrf_client.get('/oauth2-test/?%s' % query) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_passing_auth(self): - """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_token_removed_failing_auth(self): - """Ensure POSTing when there is no OAuth access token in db fails""" - self.access_token.delete() - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_with_refresh_token_failing_auth(self): - """Ensure POSTing with refresh token instead of access token fails""" - auth = self._create_authorization_header(token=self.refresh_token.token) - response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_with_expired_access_token_failing_auth(self): - """Ensure POSTing with expired access token fails with an 'Invalid token' error""" - self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late - self.access_token.save() - auth = self._create_authorization_header() - response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) - self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - self.assertIn('Invalid token', response.content) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_with_invalid_scope_failing_auth(self): - """Ensure POSTing with a readonly scope instead of a write scope fails""" - read_only_access_token = self.access_token - read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] - read_only_access_token.save() - auth = self._create_authorization_header(token=read_only_access_token.token) - response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_post_form_with_valid_scope_passing_auth(self): - """Ensure POSTing with a write scope succeed""" - read_write_access_token = self.access_token - read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] - read_write_access_token.save() - auth = self._create_authorization_header(token=read_write_access_token.token) - response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 200) - - class FailingAuthAccessedInRenderer(TestCase): def setUp(self): class AuthAccessingRenderer(renderers.BaseRenderer): diff --git a/tests/test_fields.py b/tests/test_fields.py index 775d4618..dc21c234 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -346,7 +346,7 @@ class TestBooleanField(FieldValues): False: False, } invalid_inputs = { - 'foo': ['`foo` is not a valid boolean.'], + 'foo': ['"foo" is not a valid boolean.'], None: ['This field may not be null.'] } outputs = { @@ -376,7 +376,7 @@ class TestNullBooleanField(FieldValues): None: None } invalid_inputs = { - 'foo': ['`foo` is not a valid boolean.'], + 'foo': ['"foo" is not a valid boolean.'], } outputs = { 'true': True, @@ -447,7 +447,7 @@ class TestSlugField(FieldValues): 'slug-99': 'slug-99', } invalid_inputs = { - 'slug 99': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."] + 'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.'] } outputs = {} field = serializers.SlugField() @@ -648,8 +648,8 @@ class TestDateField(FieldValues): datetime.date(2001, 1, 1): datetime.date(2001, 1, 1), } invalid_inputs = { - 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], - '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], + 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], + '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'], } outputs = { @@ -666,7 +666,7 @@ class TestCustomInputFormatDateField(FieldValues): '1 Jan 2001': datetime.date(2001, 1, 1), } invalid_inputs = { - '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY'] + '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.'] } outputs = {} field = serializers.DateField(input_formats=['%d %b %Y']) @@ -710,8 +710,8 @@ class TestDateTimeField(FieldValues): '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()) } invalid_inputs = { - 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], - '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], + 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], + '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'], } outputs = { @@ -729,7 +729,7 @@ class TestCustomInputFormatDateTimeField(FieldValues): '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()), } invalid_inputs = { - '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY'] + '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.'] } outputs = {} field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y']) @@ -781,8 +781,8 @@ class TestTimeField(FieldValues): datetime.time(13, 00): datetime.time(13, 00), } invalid_inputs = { - 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], - '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], + 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], + '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], } outputs = { datetime.time(13, 00): '13:00:00' @@ -798,7 +798,7 @@ class TestCustomInputFormatTimeField(FieldValues): '1:00pm': datetime.time(13, 00), } invalid_inputs = { - '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]'], + '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'], } outputs = {} field = serializers.TimeField(input_formats=['%I:%M%p']) @@ -840,7 +840,7 @@ class TestChoiceField(FieldValues): 'good': 'good', } invalid_inputs = { - 'amazing': ['`amazing` is not a valid choice.'] + 'amazing': ['"amazing" is not a valid choice.'] } outputs = { 'good': 'good', @@ -880,8 +880,8 @@ class TestChoiceFieldWithType(FieldValues): 3: 3, } invalid_inputs = { - 5: ['`5` is not a valid choice.'], - 'abc': ['`abc` is not a valid choice.'] + 5: ['"5" is not a valid choice.'], + 'abc': ['"abc" is not a valid choice.'] } outputs = { '1': 1, @@ -907,7 +907,7 @@ class TestChoiceFieldWithListChoices(FieldValues): 'good': 'good', } invalid_inputs = { - 'awful': ['`awful` is not a valid choice.'] + 'awful': ['"awful" is not a valid choice.'] } outputs = { 'good': 'good' @@ -925,8 +925,8 @@ class TestMultipleChoiceField(FieldValues): ('aircon', 'manual'): set(['aircon', 'manual']), } invalid_inputs = { - 'abc': ['Expected a list of items but got type `str`.'], - ('aircon', 'incorrect'): ['`incorrect` is not a valid choice.'] + 'abc': ['Expected a list of items but got type "str".'], + ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.'] } outputs = [ (['aircon', 'manual'], set(['aircon', 'manual'])) @@ -1036,7 +1036,7 @@ class TestListField(FieldValues): (['1', '2', '3'], [1, 2, 3]) ] invalid_inputs = [ - ('not a list', ['Expected a list of items but got type `str`']), + ('not a list', ['Expected a list of items but got type "str".']), ([1, 2, 'error'], ['A valid integer is required.']) ] outputs = [ diff --git a/tests/test_generics.py b/tests/test_generics.py index 94023c30..fba8718f 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -117,7 +117,7 @@ class TestRootView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'}) def test_delete_root_view(self): """ @@ -127,7 +127,7 @@ class TestRootView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'}) def test_post_cannot_set_id(self): """ @@ -181,7 +181,7 @@ class TestInstanceView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'}) def test_put_instance_view(self): """ diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 5ff59c72..972a896a 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,9 +1,7 @@ from __future__ import unicode_literals - -from rest_framework import exceptions, serializers, views +from rest_framework import exceptions, serializers, status, views from rest_framework.request import Request from rest_framework.test import APIRequestFactory -import pytest request = Request(APIRequestFactory().options('/')) @@ -17,7 +15,8 @@ class TestMetadata: """Example view.""" pass - response = ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) expected = { 'name': 'Example', 'description': 'Example view.', @@ -31,7 +30,7 @@ class TestMetadata: 'multipart/form-data' ] } - assert response.status_code == 200 + assert response.status_code == status.HTTP_200_OK assert response.data == expected def test_none_metadata(self): @@ -42,8 +41,10 @@ class TestMetadata: class ExampleView(views.APIView): metadata_class = None - with pytest.raises(exceptions.MethodNotAllowed): - ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + assert response.data == {'detail': 'Method "OPTIONS" not allowed.'} def test_actions(self): """ @@ -63,7 +64,8 @@ class TestMetadata: def get_serializer(self): return ExampleSerializer() - response = ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) expected = { 'name': 'Example', 'description': 'Example view.', @@ -104,7 +106,7 @@ class TestMetadata: } } } - assert response.status_code == 200 + assert response.status_code == status.HTTP_200_OK assert response.data == expected def test_global_permissions(self): @@ -132,8 +134,9 @@ class TestMetadata: if request.method == 'POST': raise exceptions.PermissionDenied() - response = ExampleView().options(request=request) - assert response.status_code == 200 + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_200_OK assert list(response.data['actions'].keys()) == ['PUT'] def test_object_permissions(self): @@ -161,6 +164,7 @@ class TestMetadata: if self.request.method == 'PUT': raise exceptions.PermissionDenied() - response = ExampleView().options(request=request) - assert response.status_code == 200 + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_200_OK assert list(response.data['actions'].keys()) == ['POST'] diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 247b309a..bce2008a 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -216,7 +216,7 @@ class TestRegularFieldMappings(TestCase): with self.assertRaises(ImproperlyConfigured) as excinfo: TestSerializer().fields - expected = 'Field name `invalid` is not valid for model `ModelBase`.' + expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.' assert str(excinfo.exception) == expected def test_missing_field(self): @@ -234,8 +234,8 @@ class TestRegularFieldMappings(TestCase): with self.assertRaises(AssertionError) as excinfo: TestSerializer().fields expected = ( - 'Field `missing` has been declared on serializer ' - '`TestSerializer`, but is missing from `Meta.fields`.' + "The field 'missing' was declared on serializer TestSerializer, " + "but has not been included in the 'fields' option." ) assert str(excinfo.exception) == expected @@ -637,5 +637,5 @@ class TestSerializerMetaClass(TestCase): exception = result.exception self.assertEqual( str(exception), - "Cannot set both 'fields' and 'exclude'." + "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer." ) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1fd9cf9c..7cc92347 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,553 +1,475 @@ from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.core.paginator import Paginator -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, serializers, status, pagination, filters -from rest_framework.compat import django_filters +from rest_framework import exceptions, generics, pagination, serializers, status, filters +from rest_framework.request import Request +from rest_framework.pagination import PageLink, PAGE_BREAK from rest_framework.test import APIRequestFactory -from .models import BasicModel, FilterableItem +import pytest factory = APIRequestFactory() -# Helper function to split arguments out of an url -def split_arguments_from_url(url): - if '?' not in url: - return url - - path, args = url.split('?') - args = dict(r.split('=') for r in args.split('&')) - return path, args - - -class BasicSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel - - -class FilterableItemSerializer(serializers.ModelSerializer): - class Meta: - model = FilterableItem - - -class RootView(generics.ListCreateAPIView): +class TestPaginationIntegration: """ - Example description for OPTIONS. + Integration tests. """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by = 10 + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item -class DefaultPageSizeKwargView(generics.ListAPIView): - """ - View for testing default paginate_by_param usage - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - - -class PaginateByParamView(generics.ListAPIView): - """ - View for testing custom paginate_by_param usage - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by_param = 'page_size' - - -class MaxPaginateByView(generics.ListAPIView): - """ - View for testing custom max_paginate_by usage - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by = 3 - max_paginate_by = 5 - paginate_by_param = 'page_size' - + class EvenItemsOnly(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return [item for item in queryset if item % 2 == 0] + + class BasicPagination(pagination.PageNumberPagination): + paginate_by = 5 + paginate_by_param = 'page_size' + max_paginate_by = 20 + + self.view = generics.ListAPIView.as_view( + serializer_class=PassThroughSerializer, + queryset=range(1, 101), + filter_backends=[EvenItemsOnly], + pagination_class=BasicPagination + ) -class IntegrationTestPagination(TestCase): - """ - Integration tests for paginated list views. - """ + def test_filtered_items_are_paginated(self): + request = factory.get('/', {'page': 2}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [12, 14, 16, 18, 20], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page=3', + 'count': 50 + } - def setUp(self): - """ - Create 26 BasicModel instances. - """ - for char in 'abcdefghijklmnopqrstuvwxyz': - BasicModel(text=char * 3).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = RootView.as_view() - - def test_get_paginated_root_view(self): - """ - GET requests to paginated ListCreateAPIView should return paginated results. + def test_setting_page_size(self): """ - request = factory.get('/') - # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[10:20]) - self.assertNotEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[20:]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - -class IntegrationTestPaginationAndFiltering(TestCase): - - def setUp(self): + When 'paginate_by_param' is set, the client may choose a page size. """ - Create 50 FilterableItem instances. - """ - base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) - for i in range(26): - text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. - decimal = base_data[1] + i - date = base_data[2] - datetime.timedelta(days=i * 2) - FilterableItem(text=text, decimal=decimal, date=date).save() - - self.objects = FilterableItem.objects - self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} - for obj in self.objects.all() - ] - - @unittest.skipUnless(django_filters, 'django-filter not installed') - def test_get_django_filter_paginated_filtered_root_view(self): - """ - GET requests to paginated filtered ListCreateAPIView should return - paginated results. The next and previous links should preserve the - filtered parameters. - """ - class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterFieldsRootView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - paginate_by = 10 - filter_class = DecimalFilter - filter_backends = (filters.DjangoFilterBackend,) - - view = FilterFieldsRootView.as_view() - - EXPECTED_NUM_QUERIES = 2 - - request = factory.get('/', {'decimal': '15.20'}) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[10:15]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['previous'])) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - def test_get_basic_paginated_filtered_root_view(self): - """ - Same as `test_get_django_filter_paginated_filtered_root_view`, - except using a custom filter backend instead of the django-filter - backend, - """ - - class DecimalFilterBackend(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) - - class BasicFilterFieldsRootView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - paginate_by = 10 - filter_backends = (DecimalFilterBackend,) - - view = BasicFilterFieldsRootView.as_view() - - request = factory.get('/', {'decimal': '15.20'}) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[10:15]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['previous'])) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): - """ - Unit tests for pagination of primitive objects. - """ + request = factory.get('/', {'page_size': 10}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=10', + 'count': 50 + } - def setUp(self): - self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] - paginator = Paginator(self.objects, 10) - self.first_page = paginator.page(1) - self.last_page = paginator.page(3) - - def test_native_pagination(self): - serializer = pagination.PaginationSerializer(self.first_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], '?page=2') - self.assertEqual(serializer.data['previous'], None) - self.assertEqual(serializer.data['results'], self.objects[:10]) - - serializer = pagination.PaginationSerializer(self.last_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], None) - self.assertEqual(serializer.data['previous'], '?page=2') - self.assertEqual(serializer.data['results'], self.objects[20:]) - - def test_context_available_in_result(self): + def test_setting_page_size_over_maximum(self): """ - Ensure context gets passed through to the object serializer. + When page_size parameter exceeds maxiumum allowable, + then it should be capped to the maxiumum. """ - serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) - serializer.data - results = serializer.fields[serializer.results_field] - self.assertEqual(serializer.context, results.context) - + request = factory.get('/', {'page_size': 1000}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [ + 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, + 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 + ], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=1000', + 'count': 50 + } -class TestUnpaginated(TestCase): - """ - Tests for list views without pagination. - """ + def test_additional_query_params_are_preserved(self): + request = factory.get('/', {'page': 2, 'filter': 'even'}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [12, 14, 16, 18, 20], + 'previous': 'http://testserver/?filter=even', + 'next': 'http://testserver/?filter=even&page=3', + 'count': 50 + } - def setUp(self): - """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = DefaultPageSizeKwargView.as_view() - - def test_unpaginated(self): - """ - Tests the default page size for this view. - no page size --> no limit --> no meta data - """ - request = factory.get('/') + def test_404_not_found_for_invalid_page(self): + request = factory.get('/', {'page': 'invalid'}) response = self.view(request) - self.assertEqual(response.data, self.data) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == { + 'detail': 'Invalid page "invalid": That page number is not an integer.' + } -class TestCustomPaginateByParam(TestCase): +class TestPaginationDisabledIntegration: """ - Tests for list views with default page size kwarg + Integration tests for disabled pagination. """ - def setUp(self): - """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = PaginateByParamView.as_view() - - def test_default_page_size(self): - """ - Tests the default page size for this view. - no page size --> no limit --> no meta data - """ - request = factory.get('/') - response = self.view(request).render() - self.assertEqual(response.data, self.data) + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item - def test_paginate_by_param(self): - """ - If paginate_by_param is set, the new kwarg should limit per view requests. - """ - request = factory.get('/', {'page_size': 5}) - response = self.view(request).render() - self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:5]) + self.view = generics.ListAPIView.as_view( + serializer_class=PassThroughSerializer, + queryset=range(1, 101), + pagination_class=None + ) + def test_unpaginated_list(self): + request = factory.get('/', {'page': 2}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == list(range(1, 101)) -class TestMaxPaginateByParam(TestCase): + +class TestDeprecatedStylePagination: """ - Tests for list views with max_paginate_by kwarg + Integration tests for deprecated style of setting pagination + attributes on the view. """ - def setUp(self): - """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = MaxPaginateByView.as_view() - - def test_max_paginate_by(self): - """ - If max_paginate_by is set, it should limit page size for the view. - """ - request = factory.get('/', data={'page_size': 10}) - response = self.view(request).render() - self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:5]) - - def test_max_paginate_by_without_page_size_param(self): - """ - If max_paginate_by is set, but client does not specifiy page_size, - standard `paginate_by` behavior should be used. - """ - request = factory.get('/') - response = self.view(request).render() - self.assertEqual(response.data['results'], self.data[:3]) - - -# Tests for context in pagination serializers + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item -class CustomField(serializers.ReadOnlyField): - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into custom field") - return "value" + class ExampleView(generics.ListAPIView): + serializer_class = PassThroughSerializer + queryset = range(1, 101) + pagination_class = pagination.PageNumberPagination + paginate_by = 20 + page_query_param = 'page_number' + self.view = ExampleView.as_view() -class BasicModelSerializer(serializers.Serializer): - text = CustomField() - - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into serializer") - return super(BasicSerializer, self).to_native(value) - - -class TestContextPassedToCustomField(TestCase): - def setUp(self): - BasicModel.objects.create(text='ala ma kota') + def test_paginate_by_attribute_on_view(self): + request = factory.get('/?page_number=2') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [ + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 + ], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page_number=3', + 'count': 100 + } - def test_with_pagination(self): - class ListView(generics.ListCreateAPIView): - queryset = BasicModel.objects.all() - serializer_class = BasicModelSerializer - paginate_by = 1 - self.view = ListView.as_view() - request = factory.get('/') - response = self.view(request).render() +class TestPageNumberPagination: + """ + Unit tests for `pagination.PageNumberPagination`. + """ - self.assertEqual(response.status_code, status.HTTP_200_OK) + def setup(self): + class ExamplePagination(pagination.PageNumberPagination): + paginate_by = 5 + self.pagination = ExamplePagination() + self.queryset = range(1, 101) + + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + + def get_paginated_content(self, queryset): + response = self.pagination.get_paginated_response(queryset) + return response.data + + def get_html_context(self): + return self.pagination.get_html_context() + + def test_no_page_number(self): + request = Request(factory.get('/')) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [1, 2, 3, 4, 5] + assert content == { + 'results': [1, 2, 3, 4, 5], + 'previous': None, + 'next': 'http://testserver/?page=2', + 'count': 100 + } + assert context == { + 'previous_url': None, + 'next_url': 'http://testserver/?page=2', + 'page_links': [ + PageLink('http://testserver/', 1, True, False), + PageLink('http://testserver/?page=2', 2, False, False), + PageLink('http://testserver/?page=3', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=20', 20, False, False), + ] + } + assert self.pagination.display_page_controls + assert isinstance(self.pagination.to_html(), type('')) + + def test_second_page(self): + request = Request(factory.get('/', {'page': 2})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [6, 7, 8, 9, 10] + assert content == { + 'results': [6, 7, 8, 9, 10], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page=3', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/', + 'next_url': 'http://testserver/?page=3', + 'page_links': [ + PageLink('http://testserver/', 1, False, False), + PageLink('http://testserver/?page=2', 2, True, False), + PageLink('http://testserver/?page=3', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=20', 20, False, False), + ] + } + def test_last_page(self): + request = Request(factory.get('/', {'page': 'last'})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [96, 97, 98, 99, 100] + assert content == { + 'results': [96, 97, 98, 99, 100], + 'previous': 'http://testserver/?page=19', + 'next': None, + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?page=19', + 'next_url': None, + 'page_links': [ + PageLink('http://testserver/', 1, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=18', 18, False, False), + PageLink('http://testserver/?page=19', 19, False, False), + PageLink('http://testserver/?page=20', 20, True, False), + ] + } -# Tests for custom pagination serializers + def test_invalid_page(self): + request = Request(factory.get('/', {'page': 'invalid'})) + with pytest.raises(exceptions.NotFound): + self.paginate_queryset(request) -class LinksSerializer(serializers.Serializer): - next = pagination.NextPageField(source='*') - prev = pagination.PreviousPageField(source='*') +class TestLimitOffset: + """ + Unit tests for `pagination.LimitOffsetPagination`. + """ -class CustomPaginationSerializer(pagination.BasePaginationSerializer): - links = LinksSerializer(source='*') # Takes the page object as the source - total_results = serializers.ReadOnlyField(source='paginator.count') + def setup(self): + class ExamplePagination(pagination.LimitOffsetPagination): + default_limit = 10 + self.pagination = ExamplePagination() + self.queryset = range(1, 101) + + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + + def get_paginated_content(self, queryset): + response = self.pagination.get_paginated_response(queryset) + return response.data + + def get_html_context(self): + return self.pagination.get_html_context() + + def test_no_offset(self): + request = Request(factory.get('/', {'limit': 5})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [1, 2, 3, 4, 5] + assert content == { + 'results': [1, 2, 3, 4, 5], + 'previous': None, + 'next': 'http://testserver/?limit=5&offset=5', + 'count': 100 + } + assert context == { + 'previous_url': None, + 'next_url': 'http://testserver/?limit=5&offset=5', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, True, False), + PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + assert self.pagination.display_page_controls + assert isinstance(self.pagination.to_html(), type('')) - results_field = 'objects' + def test_single_offset(self): + """ + When the offset is not a multiple of the limit we get some edge cases: + * The first page should still be offset zero. + * We may end up displaying an extra page in the pagination control. + """ + request = Request(factory.get('/', {'limit': 5, 'offset': 1})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [2, 3, 4, 5, 6] + assert content == { + 'results': [2, 3, 4, 5, 6], + 'previous': 'http://testserver/?limit=5', + 'next': 'http://testserver/?limit=5&offset=6', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5', + 'next_url': 'http://testserver/?limit=5&offset=6', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=1', 2, True, False), + PageLink('http://testserver/?limit=5&offset=6', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=96', 21, False, False), + ] + } + def test_first_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 5})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [6, 7, 8, 9, 10] + assert content == { + 'results': [6, 7, 8, 9, 10], + 'previous': 'http://testserver/?limit=5', + 'next': 'http://testserver/?limit=5&offset=10', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5', + 'next_url': 'http://testserver/?limit=5&offset=10', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=5', 2, True, False), + PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } -class CustomFooSerializer(serializers.Serializer): - foo = serializers.CharField() + def test_middle_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 10})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [11, 12, 13, 14, 15] + assert content == { + 'results': [11, 12, 13, 14, 15], + 'previous': 'http://testserver/?limit=5&offset=5', + 'next': 'http://testserver/?limit=5&offset=15', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5&offset=5', + 'next_url': 'http://testserver/?limit=5&offset=15', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + PageLink('http://testserver/?limit=5&offset=10', 3, True, False), + PageLink('http://testserver/?limit=5&offset=15', 4, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + def test_ending_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 95})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [96, 97, 98, 99, 100] + assert content == { + 'results': [96, 97, 98, 99, 100], + 'previous': 'http://testserver/?limit=5&offset=90', + 'next': None, + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5&offset=90', + 'next_url': None, + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=85', 18, False, False), + PageLink('http://testserver/?limit=5&offset=90', 19, False, False), + PageLink('http://testserver/?limit=5&offset=95', 20, True, False), + ] + } -class CustomFooPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = CustomFooSerializer + def test_invalid_offset(self): + """ + An invalid offset query param should be treated as 0. + """ + request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) + queryset = self.paginate_queryset(request) + assert queryset == [1, 2, 3, 4, 5] + def test_invalid_limit(self): + """ + An invalid limit query param should be ignored in favor of the default. + """ + request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) + queryset = self.paginate_queryset(request) + assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -class TestCustomPaginationSerializer(TestCase): - def setUp(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = Paginator(objects, 2) - self.page = paginator.page(1) - def test_custom_pagination_serializer(self): - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=self.page, - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page=2', - 'prev': None - }, - 'total_results': 4, - 'objects': ['john', 'paul'] - } - self.assertEqual(serializer.data, expected) - - def test_custom_pagination_serializer_with_custom_object_serializer(self): - objects = [ - {'foo': 'bar'}, - {'foo': 'spam'} - ] - paginator = Paginator(objects, 1) - page = paginator.page(1) - serializer = CustomFooPaginationSerializer(page) - serializer.data - - -class NonIntegerPage(object): - - def __init__(self, paginator, object_list, prev_token, token, next_token): - self.paginator = paginator - self.object_list = object_list - self.prev_token = prev_token - self.token = token - self.next_token = next_token - - def has_next(self): - return not not self.next_token - - def next_page_number(self): - return self.next_token - - def has_previous(self): - return not not self.prev_token - - def previous_page_number(self): - return self.prev_token - - -class NonIntegerPaginator(object): - - def __init__(self, object_list, per_page): - self.object_list = object_list - self.per_page = per_page - - def count(self): - # pretend like we don't know how many pages we have - return None - - def page(self, token=None): - if token: - try: - first = self.object_list.index(token) - except ValueError: - first = 0 - else: - first = 0 - n = len(self.object_list) - last = min(first + self.per_page, n) - prev_token = self.object_list[last - (2 * self.per_page)] if first else None - next_token = self.object_list[last] if last < n else None - return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): - def test_custom_pagination_serializer(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = NonIntegerPaginator(objects, 2) - - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page(), - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page={0}'.format(objects[2]), - 'prev': None - }, - 'total_results': None, - 'objects': objects[:2] - } - self.assertEqual(serializer.data, expected) +def test_get_displayed_page_numbers(): + """ + Test our contextual page display function. - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page('george'), - context={'request': request} - ) - expected = { - 'links': { - 'next': None, - 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), - }, - 'total_results': None, - 'objects': objects[2:] - } - self.assertEqual(serializer.data, expected) + This determines which pages to display in a pagination control, + given the current page and the last page. + """ + displayed_page_numbers = pagination._get_displayed_page_numbers + + # At five pages or less, all pages are displayed, always. + assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5] + + # Between six and either pages we may have a single page break. + assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6] + assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6] + assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6] + assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6] + assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6] + assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6] + + assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7] + assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7] + assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7] + assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7] + assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7] + assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7] + assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7] + + assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8] + assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8] + assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8] + assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8] + assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8] + assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8] + assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8] + assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8] + + # At nine or more pages we may have two page breaks, one on each side. + assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9] + assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9] + assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9] + assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9] + assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9] + assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9] + assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] + assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] + assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] diff --git a/tests/test_parsers.py b/tests/test_parsers.py index d28d8bd4..54455cf6 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -4,13 +4,9 @@ from __future__ import unicode_literals from django import forms from django.core.files.uploadhandler import MemoryFileUploadHandler from django.test import TestCase -from django.utils import unittest from django.utils.six.moves import StringIO -from rest_framework.compat import etree from rest_framework.exceptions import ParseError from rest_framework.parsers import FormParser, FileUploadParser -from rest_framework.parsers import XMLParser -import datetime class Form(forms.Form): @@ -32,62 +28,6 @@ class TestFormParser(TestCase): self.assertEqual(Form(data).is_valid(), True) -class TestXMLParser(TestCase): - def setUp(self): - self._input = StringIO( - '<?xml version="1.0" encoding="utf-8"?>' - '<root>' - '<field_a>121.0</field_a>' - '<field_b>dasd</field_b>' - '<field_c></field_c>' - '<field_d>2011-12-25 12:45:00</field_d>' - '</root>' - ) - self._data = { - 'field_a': 121, - 'field_b': 'dasd', - 'field_c': None, - 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00) - } - self._complex_data_input = StringIO( - '<?xml version="1.0" encoding="utf-8"?>' - '<root>' - '<creation_date>2011-12-25 12:45:00</creation_date>' - '<sub_data_list>' - '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>' - '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>' - '</sub_data_list>' - '<name>name</name>' - '</root>' - ) - self._complex_data = { - "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), - "name": "name", - "sub_data_list": [ - { - "sub_id": 1, - "sub_name": "first" - }, - { - "sub_id": 2, - "sub_name": "second" - } - ] - } - - @unittest.skipUnless(etree, 'defusedxml not installed') - def test_parse(self): - parser = XMLParser() - data = parser.parse(self._input) - self.assertEqual(data, self._data) - - @unittest.skipUnless(etree, 'defusedxml not installed') - def test_complex_data_parse(self): - parser = XMLParser() - data = parser.parse(self._complex_data_input) - self.assertEqual(data, self._complex_data) - - class TestFileUploadParser(TestCase): def setUp(self): class MockRequest(object): diff --git a/tests/test_relations.py b/tests/test_relations.py index 62353dc2..08c92242 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -33,7 +33,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): with pytest.raises(serializers.ValidationError) as excinfo: self.field.to_internal_value(4) msg = excinfo.value.detail[0] - assert msg == "Invalid pk '4' - object does not exist." + assert msg == 'Invalid pk "4" - object does not exist.' def test_pk_related_lookup_invalid_type(self): with pytest.raises(serializers.ValidationError) as excinfo: diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 00a24fb1..3e64d8fe 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -1,26 +1,19 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals - -from decimal import Decimal from django.conf.urls import patterns, url, include from django.core.cache import cache from django.db import models from django.test import TestCase -from django.utils import six, unittest -from django.utils.six import BytesIO -from django.utils.six.moves import StringIO +from django.utils import six from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions -from rest_framework.compat import yaml, etree +from rest_framework.compat import OrderedDict from rest_framework.response import Response from rest_framework.views import APIView -from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ - XMLRenderer, JSONPRenderer, BrowsableAPIRenderer -from rest_framework.parsers import YAMLParser, XMLParser +from rest_framework.renderers import BaseRenderer, JSONRenderer, BrowsableAPIRenderer from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory from collections import MutableMapping -import datetime import json import pickle import re @@ -108,8 +101,6 @@ urlpatterns = patterns( url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^cache$', MockGETView.as_view()), - url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), - url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])), url(r'^html$', HTMLView.as_view()), url(r'^html1$', HTMLView1.as_view()), @@ -409,207 +400,6 @@ class AsciiJSONRendererTests(TestCase): self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8')) -class JSONPRendererTests(TestCase): - """ - Tests specific to the JSONP Renderer - """ - - urls = 'tests.test_renderers' - - def test_without_callback_with_json_renderer(self): - """ - Test JSONP rendering with View JSON Renderer. - """ - resp = self.client.get( - '/jsonp/jsonrenderer', - HTTP_ACCEPT='application/javascript' - ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') - self.assertEqual( - resp.content, - ('callback(%s);' % _flat_repr).encode('ascii') - ) - - def test_without_callback_without_json_renderer(self): - """ - Test JSONP rendering without View JSON Renderer. - """ - resp = self.client.get( - '/jsonp/nojsonrenderer', - HTTP_ACCEPT='application/javascript' - ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') - self.assertEqual( - resp.content, - ('callback(%s);' % _flat_repr).encode('ascii') - ) - - def test_with_callback(self): - """ - Test JSONP rendering with callback function name. - """ - callback_func = 'myjsonpcallback' - resp = self.client.get( - '/jsonp/nojsonrenderer?callback=' + callback_func, - HTTP_ACCEPT='application/javascript' - ) - self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') - self.assertEqual( - resp.content, - ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii') - ) - - -if yaml: - _yaml_repr = 'foo: [bar, baz]\n' - - class YAMLRendererTests(TestCase): - """ - Tests specific to the YAML Renderer - """ - - def test_render(self): - """ - Test basic YAML rendering. - """ - obj = {'foo': ['bar', 'baz']} - renderer = YAMLRenderer() - content = renderer.render(obj, 'application/yaml') - self.assertEqual(content.decode('utf-8'), _yaml_repr) - - def test_render_and_parse(self): - """ - Test rendering and then parsing returns the original object. - IE obj -> render -> parse -> obj. - """ - obj = {'foo': ['bar', 'baz']} - - renderer = YAMLRenderer() - parser = YAMLParser() - - content = renderer.render(obj, 'application/yaml') - data = parser.parse(BytesIO(content)) - self.assertEqual(obj, data) - - def test_render_decimal(self): - """ - Test YAML decimal rendering. - """ - renderer = YAMLRenderer() - content = renderer.render({'field': Decimal('111.2')}, 'application/yaml') - self.assertYAMLContains(content.decode('utf-8'), "field: '111.2'") - - def assertYAMLContains(self, content, string): - self.assertTrue(string in content, '%r not in %r' % (string, content)) - - def test_proper_encoding(self): - obj = {'countries': ['United Kingdom', 'France', 'España']} - renderer = YAMLRenderer() - content = renderer.render(obj, 'application/yaml') - self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8')) - - -class XMLRendererTestCase(TestCase): - """ - Tests specific to the XML Renderer - """ - - _complex_data = { - "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), - "name": "name", - "sub_data_list": [ - { - "sub_id": 1, - "sub_name": "first" - }, - { - "sub_id": 2, - "sub_name": "second" - } - ] - } - - def test_render_string(self): - """ - Test XML rendering. - """ - renderer = XMLRenderer() - content = renderer.render({'field': 'astring'}, 'application/xml') - self.assertXMLContains(content, '<field>astring</field>') - - def test_render_integer(self): - """ - Test XML rendering. - """ - renderer = XMLRenderer() - content = renderer.render({'field': 111}, 'application/xml') - self.assertXMLContains(content, '<field>111</field>') - - def test_render_datetime(self): - """ - Test XML rendering. - """ - renderer = XMLRenderer() - content = renderer.render({ - 'field': datetime.datetime(2011, 12, 25, 12, 45, 00) - }, 'application/xml') - self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>') - - def test_render_float(self): - """ - Test XML rendering. - """ - renderer = XMLRenderer() - content = renderer.render({'field': 123.4}, 'application/xml') - self.assertXMLContains(content, '<field>123.4</field>') - - def test_render_decimal(self): - """ - Test XML rendering. - """ - renderer = XMLRenderer() - content = renderer.render({'field': Decimal('111.2')}, 'application/xml') - self.assertXMLContains(content, '<field>111.2</field>') - - def test_render_none(self): - """ - Test XML rendering. - """ - renderer = XMLRenderer() - content = renderer.render({'field': None}, 'application/xml') - self.assertXMLContains(content, '<field></field>') - - def test_render_complex_data(self): - """ - Test XML rendering. - """ - renderer = XMLRenderer() - content = renderer.render(self._complex_data, 'application/xml') - self.assertXMLContains(content, '<sub_name>first</sub_name>') - self.assertXMLContains(content, '<sub_name>second</sub_name>') - - @unittest.skipUnless(etree, 'defusedxml not installed') - def test_render_and_parse_complex_data(self): - """ - Test XML rendering. - """ - renderer = XMLRenderer() - content = StringIO(renderer.render(self._complex_data, 'application/xml')) - - parser = XMLParser() - complex_data_out = parser.parse(content) - error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) - self.assertEqual(self._complex_data, complex_data_out, error_msg) - - def assertXMLContains(self, xml, string): - self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) - self.assertTrue(xml.endswith('</root>')) - self.assertTrue(string in xml, '%r not in %r' % (string, xml)) - - # Tests for caching issue, #346 class CacheRenderTest(TestCase): """ @@ -699,3 +489,25 @@ class CacheRenderTest(TestCase): cached_resp = cache.get(self.cache_key) self.assertIsInstance(cached_resp, Response) self.assertEqual(cached_resp.content, resp.content) + + +class TestJSONIndentationStyles: + def test_indented(self): + renderer = JSONRenderer() + data = OrderedDict([('a', 1), ('b', 2)]) + assert renderer.render(data) == b'{"a":1,"b":2}' + + def test_compact(self): + renderer = JSONRenderer() + data = OrderedDict([('a', 1), ('b', 2)]) + context = {'indent': 4} + assert ( + renderer.render(data, renderer_context=context) == + b'{\n "a": 1,\n "b": 2\n}' + ) + + def test_long_form(self): + renderer = JSONRenderer() + renderer.compact = False + data = OrderedDict([('a', 1), ('b', 2)]) + assert renderer.render(data) == b'{"a": 1, "b": 2}' diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py index fb881a75..bc955b2e 100644 --- a/tests/test_serializer_bulk_update.py +++ b/tests/test_serializer_bulk_update.py @@ -101,7 +101,7 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items but got type `int`.']} + expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} self.assertEqual(serializer.errors, expected_errors) @@ -118,6 +118,6 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items but got type `dict`.']} + expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} self.assertEqual(serializer.errors, expected_errors) diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py index b04a937e..0cee91f1 100644 --- a/tests/test_templatetags.py +++ b/tests/test_templatetags.py @@ -54,7 +54,7 @@ class Issue1386Tests(TestCase): class URLizerTests(TestCase): """ - Test if both JSON and YAML URLs are transformed into links well + Test if JSON URLs are transformed into links well """ def _urlize_dict_check(self, data): """ @@ -73,14 +73,3 @@ class URLizerTests(TestCase): data['"foo_set": [\n "http://api/foos/1/"\n], '] = \ '"foo_set": [\n "<a href="http://api/foos/1/">http://api/foos/1/</a>"\n], ' self._urlize_dict_check(data) - - def test_yaml_with_url(self): - """ - Test if YAML URLs are transformed into links well - """ - data = {} - data['''{users: 'http://api/users/'}'''] = \ - '''{users: '<a href="http://api/users/">http://api/users/</a>'}''' - data['''foo_set: ['http://api/foos/1/']'''] = \ - '''foo_set: ['<a href="http://api/foos/1/">http://api/foos/1/</a>']''' - self._urlize_dict_check(data) diff --git a/tests/test_versioning.py b/tests/test_versioning.py new file mode 100644 index 00000000..c44f727d --- /dev/null +++ b/tests/test_versioning.py @@ -0,0 +1,223 @@ +from django.conf.urls import include, url +from rest_framework import status, versioning +from rest_framework.decorators import APIView +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory, APITestCase + + +class RequestVersionView(APIView): + def get(self, request, *args, **kwargs): + return Response({'version': request.version}) + + +class ReverseView(APIView): + def get(self, request, *args, **kwargs): + return Response({'url': reverse('another', request=request)}) + + +class RequestInvalidVersionView(APIView): + def determine_version(self, request, *args, **kwargs): + scheme = self.versioning_class() + scheme.allowed_versions = ('v1', 'v2') + return (scheme.determine_version(request, *args, **kwargs), scheme) + + def get(self, request, *args, **kwargs): + return Response({'version': request.version}) + + +factory = APIRequestFactory() + +mock_view = lambda request: None + +included_patterns = [ + url(r'^namespaced/$', mock_view, name='another'), +] + +urlpatterns = [ + url(r'^v1/', include(included_patterns, namespace='v1')), + url(r'^another/$', mock_view, name='another'), + url(r'^(?P<version>[^/]+)/another/$', mock_view, name='another') +] + + +class TestRequestVersion: + def test_unversioned(self): + view = RequestVersionView.as_view() + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'version': None} + + def test_query_param_versioning(self): + scheme = versioning.QueryParameterVersioning + view = RequestVersionView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/?version=1.2.3') + response = view(request) + assert response.data == {'version': '1.2.3'} + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'version': None} + + def test_host_name_versioning(self): + scheme = versioning.HostNameVersioning + view = RequestVersionView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_HOST='v1.example.org') + response = view(request) + assert response.data == {'version': 'v1'} + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'version': None} + + def test_accept_header_versioning(self): + scheme = versioning.AcceptHeaderVersioning + view = RequestVersionView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=1.2.3') + response = view(request) + assert response.data == {'version': '1.2.3'} + + request = factory.get('/endpoint/', HTTP_ACCEPT='application/json') + response = view(request) + assert response.data == {'version': None} + + def test_url_path_versioning(self): + scheme = versioning.URLPathVersioning + view = RequestVersionView.as_view(versioning_class=scheme) + + request = factory.get('/1.2.3/endpoint/') + response = view(request, version='1.2.3') + assert response.data == {'version': '1.2.3'} + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'version': None} + + def test_namespace_versioning(self): + class FakeResolverMatch: + namespace = 'v1' + + scheme = versioning.NamespaceVersioning + view = RequestVersionView.as_view(versioning_class=scheme) + + request = factory.get('/v1/endpoint/') + request.resolver_match = FakeResolverMatch + response = view(request, version='v1') + assert response.data == {'version': 'v1'} + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'version': None} + + +class TestURLReversing(APITestCase): + urls = 'tests.test_versioning' + + def test_reverse_unversioned(self): + view = ReverseView.as_view() + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'url': 'http://testserver/another/'} + + def test_reverse_query_param_versioning(self): + scheme = versioning.QueryParameterVersioning + view = ReverseView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/?version=v1') + response = view(request) + assert response.data == {'url': 'http://testserver/another/?version=v1'} + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'url': 'http://testserver/another/'} + + def test_reverse_host_name_versioning(self): + scheme = versioning.HostNameVersioning + view = ReverseView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_HOST='v1.example.org') + response = view(request) + assert response.data == {'url': 'http://v1.example.org/another/'} + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'url': 'http://testserver/another/'} + + def test_reverse_url_path_versioning(self): + scheme = versioning.URLPathVersioning + view = ReverseView.as_view(versioning_class=scheme) + + request = factory.get('/v1/endpoint/') + response = view(request, version='v1') + assert response.data == {'url': 'http://testserver/v1/another/'} + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'url': 'http://testserver/another/'} + + def test_reverse_namespace_versioning(self): + class FakeResolverMatch: + namespace = 'v1' + + scheme = versioning.NamespaceVersioning + view = ReverseView.as_view(versioning_class=scheme) + + request = factory.get('/v1/endpoint/') + request.resolver_match = FakeResolverMatch + response = view(request, version='v1') + assert response.data == {'url': 'http://testserver/v1/namespaced/'} + + request = factory.get('/endpoint/') + response = view(request) + assert response.data == {'url': 'http://testserver/another/'} + + +class TestInvalidVersion: + def test_invalid_query_param_versioning(self): + scheme = versioning.QueryParameterVersioning + view = RequestInvalidVersionView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/?version=v3') + response = view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_invalid_host_name_versioning(self): + scheme = versioning.HostNameVersioning + view = RequestInvalidVersionView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_HOST='v3.example.org') + response = view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_invalid_accept_header_versioning(self): + scheme = versioning.AcceptHeaderVersioning + view = RequestInvalidVersionView.as_view(versioning_class=scheme) + + request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3') + response = view(request) + assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE + + def test_invalid_url_path_versioning(self): + scheme = versioning.URLPathVersioning + view = RequestInvalidVersionView.as_view(versioning_class=scheme) + + request = factory.get('/v3/endpoint/') + response = view(request, version='v3') + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_invalid_namespace_versioning(self): + class FakeResolverMatch: + namespace = 'v3' + + scheme = versioning.NamespaceVersioning + view = RequestInvalidVersionView.as_view(versioning_class=scheme) + + request = factory.get('/v3/endpoint/') + request.resolver_match = FakeResolverMatch + response = view(request, version='v3') + assert response.status_code == status.HTTP_404_NOT_FOUND |
