aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/browsable_api/auth_urls.py1
-rw-r--r--tests/conftest.py20
-rw-r--r--tests/test_authentication.py430
-rw-r--r--tests/test_fields.py44
-rw-r--r--tests/test_generics.py6
-rw-r--r--tests/test_metadata.py60
-rw-r--r--tests/test_model_serializer.py8
-rw-r--r--tests/test_pagination.py1048
-rw-r--r--tests/test_parsers.py60
-rw-r--r--tests/test_relations.py2
-rw-r--r--tests/test_renderers.py238
-rw-r--r--tests/test_serializer_bulk_update.py4
-rw-r--r--tests/test_templatetags.py13
-rw-r--r--tests/test_versioning.py223
14 files changed, 913 insertions, 1244 deletions
diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py
index bce7dcf9..97bc1036 100644
--- a/tests/browsable_api/auth_urls.py
+++ b/tests/browsable_api/auth_urls.py
@@ -3,6 +3,7 @@ from django.conf.urls import patterns, url, include
from .views import MockView
+
urlpatterns = patterns(
'',
(r'^$', MockView.as_view()),
diff --git a/tests/conftest.py b/tests/conftest.py
index 31142eaf..44ed070b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -44,26 +44,6 @@ def pytest_configure():
),
)
- try:
- import oauth_provider # NOQA
- import oauth2 # NOQA
- except ImportError:
- pass
- else:
- settings.INSTALLED_APPS += (
- 'oauth_provider',
- )
-
- try:
- import provider # NOQA
- except ImportError:
- pass
- else:
- settings.INSTALLED_APPS += (
- 'provider',
- 'provider.oauth2',
- )
-
# guardian is optional
try:
import guardian # NOQA
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
index 44837c4e..04c5782e 100644
--- a/tests/test_authentication.py
+++ b/tests/test_authentication.py
@@ -3,8 +3,7 @@ from django.conf.urls import patterns, url, include
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import TestCase
-from django.utils import six, unittest
-from django.utils.http import urlencode
+from django.utils import six
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework import permissions
@@ -16,17 +15,11 @@ from rest_framework.authentication import (
TokenAuthentication,
BasicAuthentication,
SessionAuthentication,
- OAuthAuthentication,
- OAuth2Authentication
)
from rest_framework.authtoken.models import Token
-from rest_framework.compat import oauth2_provider, oauth2_provider_scope
-from rest_framework.compat import oauth, oauth_provider
from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
import base64
-import time
-import datetime
factory = APIRequestFactory()
@@ -50,37 +43,10 @@ urlpatterns = patterns(
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
- (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
- (
- r'^oauth-with-scope/$',
- MockView.as_view(
- authentication_classes=[OAuthAuthentication],
- permission_classes=[permissions.TokenHasReadWriteScope]
- )
- ),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework'))
)
-class OAuth2AuthenticationDebug(OAuth2Authentication):
- allow_query_params_token = True
-
-if oauth2_provider is not None:
- urlpatterns += patterns(
- '',
- url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
- url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
- url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
- url(
- r'^oauth2-with-scope-test/$',
- MockView.as_view(
- authentication_classes=[OAuth2Authentication],
- permission_classes=[permissions.TokenHasReadWriteScope]
- )
- )
- )
-
-
class BasicAuthTests(TestCase):
"""Basic authentication"""
urls = 'tests.test_authentication'
@@ -276,400 +242,6 @@ class IncorrectCredentialsTests(TestCase):
self.assertEqual(response.data, {'detail': 'Bad credentials'})
-class OAuthTests(TestCase):
- """OAuth 1.0a authentication"""
- urls = 'tests.test_authentication'
-
- def setUp(self):
- # these imports are here because oauth is optional and hiding them in try..except block or compat
- # could obscure problems if something breaks
- from oauth_provider.models import Consumer, Scope
- from oauth_provider.models import Token as OAuthToken
- from oauth_provider import consts
-
- self.consts = consts
-
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- self.CONSUMER_KEY = 'consumer_key'
- self.CONSUMER_SECRET = 'consumer_secret'
- self.TOKEN_KEY = "token_key"
- self.TOKEN_SECRET = "token_secret"
-
- self.consumer = Consumer.objects.create(
- key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
- name='example', user=self.user, status=self.consts.ACCEPTED
- )
-
- self.scope = Scope.objects.create(name="resource name", url="api/")
- self.token = OAuthToken.objects.create(
- user=self.user, consumer=self.consumer, scope=self.scope,
- token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET,
- is_approved=True
- )
-
- def _create_authorization_header(self):
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="GET", url="http://example.com", parameters=params)
-
- signature_method = oauth.SignatureMethod_PLAINTEXT()
- req.sign_request(signature_method, self.consumer, self.token)
-
- return req.to_header()["Authorization"]
-
- def _create_authorization_url_parameters(self):
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="GET", url="http://example.com", parameters=params)
-
- signature_method = oauth.SignatureMethod_PLAINTEXT()
- req.sign_request(signature_method, self.consumer, self.token)
- return dict(req)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_passing_oauth(self):
- """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_repeated_nonce_failing_oauth(self):
- """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- # simulate reply attack auth header containes already used (nonce, timestamp) pair
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_token_removed_failing_oauth(self):
- """Ensure POSTing when there is no OAuth access token in db fails"""
- self.token.delete()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_consumer_status_not_accepted_failing_oauth(self):
- """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
- for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
- self.consumer.status = consumer_status
- self.consumer.save()
-
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_request_token_failing_oauth(self):
- """Ensure POSTing with unauthorized request token instead of access token fails"""
- self.token.token_type = self.token.REQUEST
- self.token.save()
-
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_urlencoded_parameters(self):
- """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
- params = self._create_authorization_url_parameters()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_get_form_with_url_parameters(self):
- """Ensure GETing with auth in url parameters passes"""
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.get('/oauth/', params)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_hmac_sha1_signature_passes(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_get_form_with_readonly_resource_passing_auth(self):
- """Ensure POSTing with a readonly scope instead of a write scope fails"""
- read_only_access_token = self.token
- read_only_access_token.scope.is_readonly = True
- read_only_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.get('/oauth-with-scope/', params)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_readonly_resource_failing_auth(self):
- """Ensure POSTing with a readonly resource instead of a write scope fails"""
- read_only_access_token = self.token
- read_only_access_token.scope.is_readonly = True
- read_only_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.post('/oauth-with-scope/', params)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_write_resource_passing_auth(self):
- """Ensure POSTing with a write resource succeed"""
- read_write_access_token = self.token
- read_write_access_token.scope.is_readonly = False
- read_write_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_bad_consumer_key(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': 'badconsumerkey'
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_bad_token_key(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': 'badtokenkey',
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
-
-class OAuth2Tests(TestCase):
- """OAuth 2.0 authentication"""
- urls = 'tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- self.CLIENT_ID = 'client_key'
- self.CLIENT_SECRET = 'client_secret'
- self.ACCESS_TOKEN = "access_token"
- self.REFRESH_TOKEN = "refresh_token"
-
- self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(
- client_id=self.CLIENT_ID,
- client_secret=self.CLIENT_SECRET,
- redirect_uri='',
- client_type=0,
- name='example',
- user=None,
- )
-
- self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(
- token=self.ACCESS_TOKEN,
- client=self.oauth2_client,
- user=self.user,
- )
- self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(
- user=self.user,
- access_token=self.access_token,
- client=self.oauth2_client
- )
-
- def _create_authorization_header(self, token=None):
- return "Bearer {0}".format(token or self.access_token.token)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_type_failing(self):
- """Ensure that a wrong token type lead to the correct HTTP error status code"""
- auth = "Wrong token-type-obsviously"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_format_failing(self):
- """Ensure that a wrong token format lead to the correct HTTP error status code"""
- auth = "Bearer wrong token format"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_failing(self):
- """Ensure that a wrong token lead to the correct HTTP error status code"""
- auth = "Bearer wrong-token"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_missing(self):
- """Ensure that a missing token lead to the correct HTTP error status code"""
- auth = "Bearer"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_passing_auth(self):
- """Ensure GETing form over OAuth with correct client credentials succeed"""
- auth = self._create_authorization_header()
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_passing_auth_url_transport(self):
- """Ensure GETing form over OAuth with correct client credentials in form data succeed"""
- response = self.csrf_client.post(
- '/oauth2-test/',
- data={'access_token': self.access_token.token}
- )
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_passing_auth_url_transport(self):
- """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
- query = urlencode({'access_token': self.access_token.token})
- response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_failing_auth_url_transport(self):
- """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
- query = urlencode({'access_token': self.access_token.token})
- response = self.csrf_client.get('/oauth2-test/?%s' % query)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_passing_auth(self):
- """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_token_removed_failing_auth(self):
- """Ensure POSTing when there is no OAuth access token in db fails"""
- self.access_token.delete()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_refresh_token_failing_auth(self):
- """Ensure POSTing with refresh token instead of access token fails"""
- auth = self._create_authorization_header(token=self.refresh_token.token)
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_expired_access_token_failing_auth(self):
- """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
- self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
- self.access_token.save()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
- self.assertIn('Invalid token', response.content)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_invalid_scope_failing_auth(self):
- """Ensure POSTing with a readonly scope instead of a write scope fails"""
- read_only_access_token = self.access_token
- read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
- read_only_access_token.save()
- auth = self._create_authorization_header(token=read_only_access_token.token)
- response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
- response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_valid_scope_passing_auth(self):
- """Ensure POSTing with a write scope succeed"""
- read_write_access_token = self.access_token
- read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
- read_write_access_token.save()
- auth = self._create_authorization_header(token=read_write_access_token.token)
- response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
-
class FailingAuthAccessedInRenderer(TestCase):
def setUp(self):
class AuthAccessingRenderer(renderers.BaseRenderer):
diff --git a/tests/test_fields.py b/tests/test_fields.py
index 6744cf64..48ada780 100644
--- a/tests/test_fields.py
+++ b/tests/test_fields.py
@@ -347,7 +347,7 @@ class TestBooleanField(FieldValues):
False: False,
}
invalid_inputs = {
- 'foo': ['`foo` is not a valid boolean.'],
+ 'foo': ['"foo" is not a valid boolean.'],
None: ['This field may not be null.']
}
outputs = {
@@ -377,7 +377,7 @@ class TestNullBooleanField(FieldValues):
None: None
}
invalid_inputs = {
- 'foo': ['`foo` is not a valid boolean.'],
+ 'foo': ['"foo" is not a valid boolean.'],
}
outputs = {
'true': True,
@@ -448,7 +448,7 @@ class TestSlugField(FieldValues):
'slug-99': 'slug-99',
}
invalid_inputs = {
- 'slug 99': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]
+ 'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.']
}
outputs = {}
field = serializers.SlugField()
@@ -666,8 +666,8 @@ class TestDateField(FieldValues):
datetime.date(2001, 1, 1): datetime.date(2001, 1, 1),
}
invalid_inputs = {
- 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'],
- '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'],
+ 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'],
+ '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'],
datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'],
}
outputs = {
@@ -684,7 +684,7 @@ class TestCustomInputFormatDateField(FieldValues):
'1 Jan 2001': datetime.date(2001, 1, 1),
}
invalid_inputs = {
- '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY']
+ '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.']
}
outputs = {}
field = serializers.DateField(input_formats=['%d %b %Y'])
@@ -728,8 +728,8 @@ class TestDateTimeField(FieldValues):
'2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC())
}
invalid_inputs = {
- 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'],
- '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'],
+ 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
+ '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'],
}
outputs = {
@@ -747,7 +747,7 @@ class TestCustomInputFormatDateTimeField(FieldValues):
'1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()),
}
invalid_inputs = {
- '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY']
+ '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.']
}
outputs = {}
field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y'])
@@ -799,8 +799,8 @@ class TestTimeField(FieldValues):
datetime.time(13, 00): datetime.time(13, 00),
}
invalid_inputs = {
- 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'],
- '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'],
+ 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],
+ '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],
}
outputs = {
datetime.time(13, 00): '13:00:00'
@@ -816,7 +816,7 @@ class TestCustomInputFormatTimeField(FieldValues):
'1:00pm': datetime.time(13, 00),
}
invalid_inputs = {
- '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]'],
+ '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'],
}
outputs = {}
field = serializers.TimeField(input_formats=['%I:%M%p'])
@@ -858,7 +858,7 @@ class TestChoiceField(FieldValues):
'good': 'good',
}
invalid_inputs = {
- 'amazing': ['`amazing` is not a valid choice.']
+ 'amazing': ['"amazing" is not a valid choice.']
}
outputs = {
'good': 'good',
@@ -898,8 +898,8 @@ class TestChoiceFieldWithType(FieldValues):
3: 3,
}
invalid_inputs = {
- 5: ['`5` is not a valid choice.'],
- 'abc': ['`abc` is not a valid choice.']
+ 5: ['"5" is not a valid choice.'],
+ 'abc': ['"abc" is not a valid choice.']
}
outputs = {
'1': 1,
@@ -925,7 +925,7 @@ class TestChoiceFieldWithListChoices(FieldValues):
'good': 'good',
}
invalid_inputs = {
- 'awful': ['`awful` is not a valid choice.']
+ 'awful': ['"awful" is not a valid choice.']
}
outputs = {
'good': 'good'
@@ -943,8 +943,8 @@ class TestMultipleChoiceField(FieldValues):
('aircon', 'manual'): set(['aircon', 'manual']),
}
invalid_inputs = {
- 'abc': ['Expected a list of items but got type `str`.'],
- ('aircon', 'incorrect'): ['`incorrect` is not a valid choice.']
+ 'abc': ['Expected a list of items but got type "str".'],
+ ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.']
}
outputs = [
(['aircon', 'manual'], set(['aircon', 'manual']))
@@ -1054,7 +1054,7 @@ class TestListField(FieldValues):
(['1', '2', '3'], [1, 2, 3])
]
invalid_inputs = [
- ('not a list', ['Expected a list of items but got type `str`']),
+ ('not a list', ['Expected a list of items but got type "str".']),
([1, 2, 'error'], ['A valid integer is required.'])
]
outputs = [
@@ -1072,7 +1072,7 @@ class TestUnvalidatedListField(FieldValues):
([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),
]
invalid_inputs = [
- ('not a list', ['Expected a list of items but got type `str`']),
+ ('not a list', ['Expected a list of items but got type "str".']),
]
outputs = [
([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),
@@ -1089,7 +1089,7 @@ class TestDictField(FieldValues):
]
invalid_inputs = [
({'a': 1, 'b': None}, ['This field may not be null.']),
- ('not a dict', ['Expected a dictionary of items but got type `str`']),
+ ('not a dict', ['Expected a dictionary of items but got type "str".']),
]
outputs = [
({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
@@ -1105,7 +1105,7 @@ class TestUnvalidatedDictField(FieldValues):
({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}),
]
invalid_inputs = [
- ('not a dict', ['Expected a dictionary of items but got type `str`']),
+ ('not a dict', ['Expected a dictionary of items but got type "str".']),
]
outputs = [
({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}),
diff --git a/tests/test_generics.py b/tests/test_generics.py
index 94023c30..fba8718f 100644
--- a/tests/test_generics.py
+++ b/tests/test_generics.py
@@ -117,7 +117,7 @@ class TestRootView(TestCase):
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
+ self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'})
def test_delete_root_view(self):
"""
@@ -127,7 +127,7 @@ class TestRootView(TestCase):
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
+ self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'})
def test_post_cannot_set_id(self):
"""
@@ -181,7 +181,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
+ self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'})
def test_put_instance_view(self):
"""
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 5ff59c72..5031c0f3 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -1,9 +1,8 @@
from __future__ import unicode_literals
-
-from rest_framework import exceptions, serializers, views
+from rest_framework import exceptions, serializers, status, views, versioning
from rest_framework.request import Request
+from rest_framework.renderers import BrowsableAPIRenderer
from rest_framework.test import APIRequestFactory
-import pytest
request = Request(APIRequestFactory().options('/'))
@@ -17,7 +16,8 @@ class TestMetadata:
"""Example view."""
pass
- response = ExampleView().options(request=request)
+ view = ExampleView.as_view()
+ response = view(request=request)
expected = {
'name': 'Example',
'description': 'Example view.',
@@ -31,7 +31,7 @@ class TestMetadata:
'multipart/form-data'
]
}
- assert response.status_code == 200
+ assert response.status_code == status.HTTP_200_OK
assert response.data == expected
def test_none_metadata(self):
@@ -42,8 +42,10 @@ class TestMetadata:
class ExampleView(views.APIView):
metadata_class = None
- with pytest.raises(exceptions.MethodNotAllowed):
- ExampleView().options(request=request)
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
+ assert response.data == {'detail': 'Method "OPTIONS" not allowed.'}
def test_actions(self):
"""
@@ -63,7 +65,8 @@ class TestMetadata:
def get_serializer(self):
return ExampleSerializer()
- response = ExampleView().options(request=request)
+ view = ExampleView.as_view()
+ response = view(request=request)
expected = {
'name': 'Example',
'description': 'Example view.',
@@ -104,7 +107,7 @@ class TestMetadata:
}
}
}
- assert response.status_code == 200
+ assert response.status_code == status.HTTP_200_OK
assert response.data == expected
def test_global_permissions(self):
@@ -132,8 +135,9 @@ class TestMetadata:
if request.method == 'POST':
raise exceptions.PermissionDenied()
- response = ExampleView().options(request=request)
- assert response.status_code == 200
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_200_OK
assert list(response.data['actions'].keys()) == ['PUT']
def test_object_permissions(self):
@@ -161,6 +165,36 @@ class TestMetadata:
if self.request.method == 'PUT':
raise exceptions.PermissionDenied()
- response = ExampleView().options(request=request)
- assert response.status_code == 200
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_200_OK
assert list(response.data['actions'].keys()) == ['POST']
+
+ def test_bug_2455_clone_request(self):
+ class ExampleView(views.APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+
+ def post(self, request):
+ pass
+
+ def get_serializer(self):
+ assert hasattr(self.request, 'version')
+ return serializers.Serializer()
+
+ view = ExampleView.as_view()
+ view(request=request)
+
+ def test_bug_2477_clone_request(self):
+ class ExampleView(views.APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+
+ def post(self, request):
+ pass
+
+ def get_serializer(self):
+ assert hasattr(self.request, 'versioning_scheme')
+ return serializers.Serializer()
+
+ scheme = versioning.QueryParameterVersioning
+ view = ExampleView.as_view(versioning_class=scheme)
+ view(request=request)
diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py
index 247b309a..bce2008a 100644
--- a/tests/test_model_serializer.py
+++ b/tests/test_model_serializer.py
@@ -216,7 +216,7 @@ class TestRegularFieldMappings(TestCase):
with self.assertRaises(ImproperlyConfigured) as excinfo:
TestSerializer().fields
- expected = 'Field name `invalid` is not valid for model `ModelBase`.'
+ expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.'
assert str(excinfo.exception) == expected
def test_missing_field(self):
@@ -234,8 +234,8 @@ class TestRegularFieldMappings(TestCase):
with self.assertRaises(AssertionError) as excinfo:
TestSerializer().fields
expected = (
- 'Field `missing` has been declared on serializer '
- '`TestSerializer`, but is missing from `Meta.fields`.'
+ "The field 'missing' was declared on serializer TestSerializer, "
+ "but has not been included in the 'fields' option."
)
assert str(excinfo.exception) == expected
@@ -637,5 +637,5 @@ class TestSerializerMetaClass(TestCase):
exception = result.exception
self.assertEqual(
str(exception),
- "Cannot set both 'fields' and 'exclude'."
+ "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer."
)
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
index 1fd9cf9c..13bfb627 100644
--- a/tests/test_pagination.py
+++ b/tests/test_pagination.py
@@ -1,553 +1,671 @@
+# coding: utf-8
from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.core.paginator import Paginator
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import generics, serializers, status, pagination, filters
-from rest_framework.compat import django_filters
+from rest_framework import exceptions, generics, pagination, serializers, status, filters
+from rest_framework.request import Request
+from rest_framework.pagination import PageLink, PAGE_BREAK
from rest_framework.test import APIRequestFactory
-from .models import BasicModel, FilterableItem
+import pytest
factory = APIRequestFactory()
-# Helper function to split arguments out of an url
-def split_arguments_from_url(url):
- if '?' not in url:
- return url
+class TestPaginationIntegration:
+ """
+ Integration tests.
+ """
- path, args = url.split('?')
- args = dict(r.split('=') for r in args.split('&'))
- return path, args
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
+ class EvenItemsOnly(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return [item for item in queryset if item % 2 == 0]
+
+ class BasicPagination(pagination.PageNumberPagination):
+ paginate_by = 5
+ paginate_by_param = 'page_size'
+ max_paginate_by = 20
+
+ self.view = generics.ListAPIView.as_view(
+ serializer_class=PassThroughSerializer,
+ queryset=range(1, 101),
+ filter_backends=[EvenItemsOnly],
+ pagination_class=BasicPagination
+ )
-class BasicSerializer(serializers.ModelSerializer):
- class Meta:
- model = BasicModel
+ def test_filtered_items_are_paginated(self):
+ request = factory.get('/', {'page': 2})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [12, 14, 16, 18, 20],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page=3',
+ 'count': 50
+ }
+ def test_setting_page_size(self):
+ """
+ When 'paginate_by_param' is set, the client may choose a page size.
+ """
+ request = factory.get('/', {'page_size': 10})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=10',
+ 'count': 50
+ }
-class FilterableItemSerializer(serializers.ModelSerializer):
- class Meta:
- model = FilterableItem
+ def test_setting_page_size_over_maximum(self):
+ """
+ When page_size parameter exceeds maxiumum allowable,
+ then it should be capped to the maxiumum.
+ """
+ request = factory.get('/', {'page_size': 1000})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [
+ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20,
+ 22, 24, 26, 28, 30, 32, 34, 36, 38, 40
+ ],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=1000',
+ 'count': 50
+ }
+ def test_setting_page_size_to_zero(self):
+ """
+ When page_size parameter is invalid it should return to the default.
+ """
+ request = factory.get('/', {'page_size': 0})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [2, 4, 6, 8, 10],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=0',
+ 'count': 50
+ }
-class RootView(generics.ListCreateAPIView):
- """
- Example description for OPTIONS.
- """
- queryset = BasicModel.objects.all()
- serializer_class = BasicSerializer
- paginate_by = 10
+ def test_additional_query_params_are_preserved(self):
+ request = factory.get('/', {'page': 2, 'filter': 'even'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [12, 14, 16, 18, 20],
+ 'previous': 'http://testserver/?filter=even',
+ 'next': 'http://testserver/?filter=even&page=3',
+ 'count': 50
+ }
+ def test_404_not_found_for_zero_page(self):
+ request = factory.get('/', {'page': '0'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.data == {
+ 'detail': 'Invalid page "0": That page number is less than 1.'
+ }
-class DefaultPageSizeKwargView(generics.ListAPIView):
- """
- View for testing default paginate_by_param usage
- """
- queryset = BasicModel.objects.all()
- serializer_class = BasicSerializer
+ def test_404_not_found_for_invalid_page(self):
+ request = factory.get('/', {'page': 'invalid'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.data == {
+ 'detail': 'Invalid page "invalid": That page number is not an integer.'
+ }
-class PaginateByParamView(generics.ListAPIView):
+class TestPaginationDisabledIntegration:
"""
- View for testing custom paginate_by_param usage
+ Integration tests for disabled pagination.
"""
- queryset = BasicModel.objects.all()
- serializer_class = BasicSerializer
- paginate_by_param = 'page_size'
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
-class MaxPaginateByView(generics.ListAPIView):
- """
- View for testing custom max_paginate_by usage
- """
- queryset = BasicModel.objects.all()
- serializer_class = BasicSerializer
- paginate_by = 3
- max_paginate_by = 5
- paginate_by_param = 'page_size'
+ self.view = generics.ListAPIView.as_view(
+ serializer_class=PassThroughSerializer,
+ queryset=range(1, 101),
+ pagination_class=None
+ )
+
+ def test_unpaginated_list(self):
+ request = factory.get('/', {'page': 2})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == list(range(1, 101))
-class IntegrationTestPagination(TestCase):
+class TestDeprecatedStylePagination:
"""
- Integration tests for paginated list views.
+ Integration tests for deprecated style of setting pagination
+ attributes on the view.
"""
- def setUp(self):
- """
- Create 26 BasicModel instances.
- """
- for char in 'abcdefghijklmnopqrstuvwxyz':
- BasicModel(text=char * 3).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = RootView.as_view()
-
- def test_get_paginated_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- request = factory.get('/')
- # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[10:20])
- self.assertNotEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[20:])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
-
-class IntegrationTestPaginationAndFiltering(TestCase):
-
- def setUp(self):
- """
- Create 50 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(26):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()}
- for obj in self.objects.all()
- ]
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_django_filter_paginated_filtered_root_view(self):
- """
- GET requests to paginated filtered ListCreateAPIView should return
- paginated results. The next and previous links should preserve the
- filtered parameters.
- """
- class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterFieldsRootView(generics.ListCreateAPIView):
- queryset = FilterableItem.objects.all()
- serializer_class = FilterableItemSerializer
- paginate_by = 10
- filter_class = DecimalFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- view = FilterFieldsRootView.as_view()
-
- EXPECTED_NUM_QUERIES = 2
-
- request = factory.get('/', {'decimal': '15.20'})
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[10:15])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['previous']))
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- def test_get_basic_paginated_filtered_root_view(self):
- """
- Same as `test_get_django_filter_paginated_filtered_root_view`,
- except using a custom filter backend instead of the django-filter
- backend,
- """
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
- class DecimalFilterBackend(filters.BaseFilterBackend):
- def filter_queryset(self, request, queryset, view):
- return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
-
- class BasicFilterFieldsRootView(generics.ListCreateAPIView):
- queryset = FilterableItem.objects.all()
- serializer_class = FilterableItemSerializer
- paginate_by = 10
- filter_backends = (DecimalFilterBackend,)
-
- view = BasicFilterFieldsRootView.as_view()
-
- request = factory.get('/', {'decimal': '15.20'})
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[10:15])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['previous']))
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
-
-class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
- class Meta:
- object_serializer_class = serializers.Serializer
-
-
-class UnitTestPagination(TestCase):
- """
- Unit tests for pagination of primitive objects.
- """
+ class ExampleView(generics.ListAPIView):
+ serializer_class = PassThroughSerializer
+ queryset = range(1, 101)
+ pagination_class = pagination.PageNumberPagination
+ paginate_by = 20
+ page_query_param = 'page_number'
- def setUp(self):
- self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
- paginator = Paginator(self.objects, 10)
- self.first_page = paginator.page(1)
- self.last_page = paginator.page(3)
-
- def test_native_pagination(self):
- serializer = pagination.PaginationSerializer(self.first_page)
- self.assertEqual(serializer.data['count'], 26)
- self.assertEqual(serializer.data['next'], '?page=2')
- self.assertEqual(serializer.data['previous'], None)
- self.assertEqual(serializer.data['results'], self.objects[:10])
-
- serializer = pagination.PaginationSerializer(self.last_page)
- self.assertEqual(serializer.data['count'], 26)
- self.assertEqual(serializer.data['next'], None)
- self.assertEqual(serializer.data['previous'], '?page=2')
- self.assertEqual(serializer.data['results'], self.objects[20:])
-
- def test_context_available_in_result(self):
- """
- Ensure context gets passed through to the object serializer.
- """
- serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
- serializer.data
- results = serializer.fields[serializer.results_field]
- self.assertEqual(serializer.context, results.context)
+ self.view = ExampleView.as_view()
+
+ def test_paginate_by_attribute_on_view(self):
+ request = factory.get('/?page_number=2')
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [
+ 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40
+ ],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page_number=3',
+ 'count': 100
+ }
-class TestUnpaginated(TestCase):
+class TestPageNumberPagination:
"""
- Tests for list views without pagination.
+ Unit tests for `pagination.PageNumberPagination`.
"""
- def setUp(self):
- """
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = DefaultPageSizeKwargView.as_view()
-
- def test_unpaginated(self):
- """
- Tests the default page size for this view.
- no page size --> no limit --> no meta data
- """
- request = factory.get('/')
- response = self.view(request)
- self.assertEqual(response.data, self.data)
+ def setup(self):
+ class ExamplePagination(pagination.PageNumberPagination):
+ paginate_by = 5
+ self.pagination = ExamplePagination()
+ self.queryset = range(1, 101)
+
+ def paginate_queryset(self, request):
+ return list(self.pagination.paginate_queryset(self.queryset, request))
+
+ def get_paginated_content(self, queryset):
+ response = self.pagination.get_paginated_response(queryset)
+ return response.data
+
+ def get_html_context(self):
+ return self.pagination.get_html_context()
+
+ def test_no_page_number(self):
+ request = Request(factory.get('/'))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [1, 2, 3, 4, 5]
+ assert content == {
+ 'results': [1, 2, 3, 4, 5],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': None,
+ 'next_url': 'http://testserver/?page=2',
+ 'page_links': [
+ PageLink('http://testserver/', 1, True, False),
+ PageLink('http://testserver/?page=2', 2, False, False),
+ PageLink('http://testserver/?page=3', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=20', 20, False, False),
+ ]
+ }
+ assert self.pagination.display_page_controls
+ assert isinstance(self.pagination.to_html(), type(''))
+
+ def test_second_page(self):
+ request = Request(factory.get('/', {'page': 2}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [6, 7, 8, 9, 10]
+ assert content == {
+ 'results': [6, 7, 8, 9, 10],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page=3',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/',
+ 'next_url': 'http://testserver/?page=3',
+ 'page_links': [
+ PageLink('http://testserver/', 1, False, False),
+ PageLink('http://testserver/?page=2', 2, True, False),
+ PageLink('http://testserver/?page=3', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=20', 20, False, False),
+ ]
+ }
+
+ def test_last_page(self):
+ request = Request(factory.get('/', {'page': 'last'}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [96, 97, 98, 99, 100]
+ assert content == {
+ 'results': [96, 97, 98, 99, 100],
+ 'previous': 'http://testserver/?page=19',
+ 'next': None,
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?page=19',
+ 'next_url': None,
+ 'page_links': [
+ PageLink('http://testserver/', 1, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=18', 18, False, False),
+ PageLink('http://testserver/?page=19', 19, False, False),
+ PageLink('http://testserver/?page=20', 20, True, False),
+ ]
+ }
+
+ def test_invalid_page(self):
+ request = Request(factory.get('/', {'page': 'invalid'}))
+ with pytest.raises(exceptions.NotFound):
+ self.paginate_queryset(request)
-class TestCustomPaginateByParam(TestCase):
+class TestLimitOffset:
"""
- Tests for list views with default page size kwarg
+ Unit tests for `pagination.LimitOffsetPagination`.
"""
- def setUp(self):
+ def setup(self):
+ class ExamplePagination(pagination.LimitOffsetPagination):
+ default_limit = 10
+ self.pagination = ExamplePagination()
+ self.queryset = range(1, 101)
+
+ def paginate_queryset(self, request):
+ return list(self.pagination.paginate_queryset(self.queryset, request))
+
+ def get_paginated_content(self, queryset):
+ response = self.pagination.get_paginated_response(queryset)
+ return response.data
+
+ def get_html_context(self):
+ return self.pagination.get_html_context()
+
+ def test_no_offset(self):
+ request = Request(factory.get('/', {'limit': 5}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [1, 2, 3, 4, 5]
+ assert content == {
+ 'results': [1, 2, 3, 4, 5],
+ 'previous': None,
+ 'next': 'http://testserver/?limit=5&offset=5',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': None,
+ 'next_url': 'http://testserver/?limit=5&offset=5',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, True, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+ assert self.pagination.display_page_controls
+ assert isinstance(self.pagination.to_html(), type(''))
+
+ def test_single_offset(self):
"""
- Create 13 BasicModel instances.
+ When the offset is not a multiple of the limit we get some edge cases:
+ * The first page should still be offset zero.
+ * We may end up displaying an extra page in the pagination control.
"""
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = PaginateByParamView.as_view()
-
- def test_default_page_size(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 1}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [2, 3, 4, 5, 6]
+ assert content == {
+ 'results': [2, 3, 4, 5, 6],
+ 'previous': 'http://testserver/?limit=5',
+ 'next': 'http://testserver/?limit=5&offset=6',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5',
+ 'next_url': 'http://testserver/?limit=5&offset=6',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=1', 2, True, False),
+ PageLink('http://testserver/?limit=5&offset=6', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=96', 21, False, False),
+ ]
+ }
+
+ def test_first_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 5}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [6, 7, 8, 9, 10]
+ assert content == {
+ 'results': [6, 7, 8, 9, 10],
+ 'previous': 'http://testserver/?limit=5',
+ 'next': 'http://testserver/?limit=5&offset=10',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5',
+ 'next_url': 'http://testserver/?limit=5&offset=10',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, True, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+
+ def test_middle_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 10}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [11, 12, 13, 14, 15]
+ assert content == {
+ 'results': [11, 12, 13, 14, 15],
+ 'previous': 'http://testserver/?limit=5&offset=5',
+ 'next': 'http://testserver/?limit=5&offset=15',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5&offset=5',
+ 'next_url': 'http://testserver/?limit=5&offset=15',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, True, False),
+ PageLink('http://testserver/?limit=5&offset=15', 4, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+
+ def test_ending_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 95}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [96, 97, 98, 99, 100]
+ assert content == {
+ 'results': [96, 97, 98, 99, 100],
+ 'previous': 'http://testserver/?limit=5&offset=90',
+ 'next': None,
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5&offset=90',
+ 'next_url': None,
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=85', 18, False, False),
+ PageLink('http://testserver/?limit=5&offset=90', 19, False, False),
+ PageLink('http://testserver/?limit=5&offset=95', 20, True, False),
+ ]
+ }
+
+ def test_invalid_offset(self):
"""
- Tests the default page size for this view.
- no page size --> no limit --> no meta data
+ An invalid offset query param should be treated as 0.
"""
- request = factory.get('/')
- response = self.view(request).render()
- self.assertEqual(response.data, self.data)
+ request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'}))
+ queryset = self.paginate_queryset(request)
+ assert queryset == [1, 2, 3, 4, 5]
- def test_paginate_by_param(self):
+ def test_invalid_limit(self):
"""
- If paginate_by_param is set, the new kwarg should limit per view requests.
+ An invalid limit query param should be ignored in favor of the default.
"""
- request = factory.get('/', {'page_size': 5})
- response = self.view(request).render()
- self.assertEqual(response.data['count'], 13)
- self.assertEqual(response.data['results'], self.data[:5])
+ request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0}))
+ queryset = self.paginate_queryset(request)
+ assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-class TestMaxPaginateByParam(TestCase):
+class TestCursorPagination:
"""
- Tests for list views with max_paginate_by kwarg
+ Unit tests for `pagination.CursorPagination`.
"""
- def setUp(self):
+ def setup(self):
+ class MockObject(object):
+ def __init__(self, idx):
+ self.created = idx
+
+ class MockQuerySet(object):
+ def __init__(self, items):
+ self.items = items
+
+ def filter(self, created__gt=None, created__lt=None):
+ if created__gt is not None:
+ return MockQuerySet([
+ item for item in self.items
+ if item.created > int(created__gt)
+ ])
+
+ assert created__lt is not None
+ return MockQuerySet([
+ item for item in self.items
+ if item.created < int(created__lt)
+ ])
+
+ def order_by(self, *ordering):
+ if ordering[0].startswith('-'):
+ return MockQuerySet(list(reversed(self.items)))
+ return self
+
+ def __getitem__(self, sliced):
+ return self.items[sliced]
+
+ class ExamplePagination(pagination.CursorPagination):
+ page_size = 5
+ ordering = 'created'
+
+ self.pagination = ExamplePagination()
+ self.queryset = MockQuerySet([
+ MockObject(idx) for idx in [
+ 1, 1, 1, 1, 1,
+ 1, 2, 3, 4, 4,
+ 4, 4, 5, 6, 7,
+ 7, 7, 7, 7, 7,
+ 7, 7, 7, 8, 9,
+ 9, 9, 9, 9, 9
+ ]
+ ])
+
+ def get_pages(self, url):
"""
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = MaxPaginateByView.as_view()
-
- def test_max_paginate_by(self):
- """
- If max_paginate_by is set, it should limit page size for the view.
- """
- request = factory.get('/', data={'page_size': 10})
- response = self.view(request).render()
- self.assertEqual(response.data['count'], 13)
- self.assertEqual(response.data['results'], self.data[:5])
+ Given a URL return a tuple of:
- def test_max_paginate_by_without_page_size_param(self):
+ (previous page, current page, next page, previous url, next url)
"""
- If max_paginate_by is set, but client does not specifiy page_size,
- standard `paginate_by` behavior should be used.
- """
- request = factory.get('/')
- response = self.view(request).render()
- self.assertEqual(response.data['results'], self.data[:3])
-
-
-# Tests for context in pagination serializers
+ request = Request(factory.get(url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ current = [item.created for item in queryset]
-class CustomField(serializers.ReadOnlyField):
- def to_native(self, value):
- if 'view' not in self.context:
- raise RuntimeError("context isn't getting passed into custom field")
- return "value"
+ next_url = self.pagination.get_next_link()
+ previous_url = self.pagination.get_previous_link()
+ if next_url is not None:
+ request = Request(factory.get(next_url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ next = [item.created for item in queryset]
+ else:
+ next = None
-class BasicModelSerializer(serializers.Serializer):
- text = CustomField()
-
- def to_native(self, value):
- if 'view' not in self.context:
- raise RuntimeError("context isn't getting passed into serializer")
- return super(BasicSerializer, self).to_native(value)
+ if previous_url is not None:
+ request = Request(factory.get(previous_url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ previous = [item.created for item in queryset]
+ else:
+ previous = None
+ return (previous, current, next, previous_url, next_url)
-class TestContextPassedToCustomField(TestCase):
- def setUp(self):
- BasicModel.objects.create(text='ala ma kota')
+ def test_invalid_cursor(self):
+ request = Request(factory.get('/', {'cursor': '123'}))
+ with pytest.raises(exceptions.NotFound):
+ self.pagination.paginate_queryset(self.queryset, request)
- def test_with_pagination(self):
- class ListView(generics.ListCreateAPIView):
- queryset = BasicModel.objects.all()
- serializer_class = BasicModelSerializer
- paginate_by = 1
+ def test_use_with_ordering_filter(self):
+ class MockView:
+ filter_backends = (filters.OrderingFilter,)
+ ordering_fields = ['username', 'created']
+ ordering = 'created'
- self.view = ListView.as_view()
- request = factory.get('/')
- response = self.view(request).render()
+ request = Request(factory.get('/', {'ordering': 'username'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('username',)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
+ request = Request(factory.get('/', {'ordering': '-username'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('-username',)
+ request = Request(factory.get('/', {'ordering': 'invalid'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('created',)
-# Tests for custom pagination serializers
+ def test_cursor_pagination(self):
+ (previous, current, next, previous_url, next_url) = self.get_pages('/')
-class LinksSerializer(serializers.Serializer):
- next = pagination.NextPageField(source='*')
- prev = pagination.PreviousPageField(source='*')
+ assert previous is None
+ assert current == [1, 1, 1, 1, 1]
+ assert next == [1, 2, 3, 4, 4]
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
-class CustomPaginationSerializer(pagination.BasePaginationSerializer):
- links = LinksSerializer(source='*') # Takes the page object as the source
- total_results = serializers.ReadOnlyField(source='paginator.count')
+ assert previous == [1, 1, 1, 1, 1]
+ assert current == [1, 2, 3, 4, 4]
+ assert next == [4, 4, 5, 6, 7]
- results_field = 'objects'
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+ assert previous == [1, 2, 3, 4, 4]
+ assert current == [4, 4, 5, 6, 7]
+ assert next == [7, 7, 7, 7, 7]
-class CustomFooSerializer(serializers.Serializer):
- foo = serializers.CharField()
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+ assert previous == [4, 4, 4, 5, 6] # Paging artifact
+ assert current == [7, 7, 7, 7, 7]
+ assert next == [7, 7, 7, 8, 9]
-class CustomFooPaginationSerializer(pagination.PaginationSerializer):
- class Meta:
- object_serializer_class = CustomFooSerializer
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+ assert previous == [7, 7, 7, 7, 7]
+ assert current == [7, 7, 7, 8, 9]
+ assert next == [9, 9, 9, 9, 9]
-class TestCustomPaginationSerializer(TestCase):
- def setUp(self):
- objects = ['john', 'paul', 'george', 'ringo']
- paginator = Paginator(objects, 2)
- self.page = paginator.page(1)
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
- def test_custom_pagination_serializer(self):
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=self.page,
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': 'http://testserver/foobar?page=2',
- 'prev': None
- },
- 'total_results': 4,
- 'objects': ['john', 'paul']
- }
- self.assertEqual(serializer.data, expected)
+ assert previous == [7, 7, 7, 8, 9]
+ assert current == [9, 9, 9, 9, 9]
+ assert next is None
- def test_custom_pagination_serializer_with_custom_object_serializer(self):
- objects = [
- {'foo': 'bar'},
- {'foo': 'spam'}
- ]
- paginator = Paginator(objects, 1)
- page = paginator.page(1)
- serializer = CustomFooPaginationSerializer(page)
- serializer.data
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
+ assert previous == [7, 7, 7, 7, 7]
+ assert current == [7, 7, 7, 8, 9]
+ assert next == [9, 9, 9, 9, 9]
-class NonIntegerPage(object):
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
- def __init__(self, paginator, object_list, prev_token, token, next_token):
- self.paginator = paginator
- self.object_list = object_list
- self.prev_token = prev_token
- self.token = token
- self.next_token = next_token
+ assert previous == [4, 4, 5, 6, 7]
+ assert current == [7, 7, 7, 7, 7]
+ assert next == [8, 9, 9, 9, 9] # Paging artifact
- def has_next(self):
- return not not self.next_token
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
- def next_page_number(self):
- return self.next_token
+ assert previous == [1, 2, 3, 4, 4]
+ assert current == [4, 4, 5, 6, 7]
+ assert next == [7, 7, 7, 7, 7]
- def has_previous(self):
- return not not self.prev_token
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
- def previous_page_number(self):
- return self.prev_token
+ assert previous == [1, 1, 1, 1, 1]
+ assert current == [1, 2, 3, 4, 4]
+ assert next == [4, 4, 5, 6, 7]
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
-class NonIntegerPaginator(object):
+ assert previous is None
+ assert current == [1, 1, 1, 1, 1]
+ assert next == [1, 2, 3, 4, 4]
- def __init__(self, object_list, per_page):
- self.object_list = object_list
- self.per_page = per_page
+ assert isinstance(self.pagination.to_html(), type(''))
- def count(self):
- # pretend like we don't know how many pages we have
- return None
- def page(self, token=None):
- if token:
- try:
- first = self.object_list.index(token)
- except ValueError:
- first = 0
- else:
- first = 0
- n = len(self.object_list)
- last = min(first + self.per_page, n)
- prev_token = self.object_list[last - (2 * self.per_page)] if first else None
- next_token = self.object_list[last] if last < n else None
- return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
-
-
-class TestNonIntegerPagination(TestCase):
- def test_custom_pagination_serializer(self):
- objects = ['john', 'paul', 'george', 'ringo']
- paginator = NonIntegerPaginator(objects, 2)
-
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=paginator.page(),
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
- 'prev': None
- },
- 'total_results': None,
- 'objects': objects[:2]
- }
- self.assertEqual(serializer.data, expected)
+def test_get_displayed_page_numbers():
+ """
+ Test our contextual page display function.
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=paginator.page('george'),
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': None,
- 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
- },
- 'total_results': None,
- 'objects': objects[2:]
- }
- self.assertEqual(serializer.data, expected)
+ This determines which pages to display in a pagination control,
+ given the current page and the last page.
+ """
+ displayed_page_numbers = pagination._get_displayed_page_numbers
+
+ # At five pages or less, all pages are displayed, always.
+ assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5]
+
+ # Between six and either pages we may have a single page break.
+ assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6]
+ assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6]
+ assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6]
+ assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6]
+ assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6]
+ assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6]
+
+ assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7]
+ assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7]
+ assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7]
+ assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7]
+ assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7]
+ assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7]
+ assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7]
+
+ assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8]
+ assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8]
+ assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8]
+ assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8]
+ assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8]
+ assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8]
+ assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8]
+ assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8]
+
+ # At nine or more pages we may have two page breaks, one on each side.
+ assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9]
+ assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9]
+ assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9]
+ assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9]
+ assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9]
+ assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9]
+ assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9]
+ assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9]
+ assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9]
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
index 1d2054ac..8816065a 100644
--- a/tests/test_parsers.py
+++ b/tests/test_parsers.py
@@ -4,13 +4,9 @@ from __future__ import unicode_literals
from django import forms
from django.core.files.uploadhandler import MemoryFileUploadHandler
from django.test import TestCase
-from django.utils import unittest
from django.utils.six.moves import StringIO
-from rest_framework.compat import etree
from rest_framework.exceptions import ParseError
from rest_framework.parsers import FormParser, FileUploadParser
-from rest_framework.parsers import XMLParser
-import datetime
class Form(forms.Form):
@@ -32,62 +28,6 @@ class TestFormParser(TestCase):
self.assertEqual(Form(data).is_valid(), True)
-class TestXMLParser(TestCase):
- def setUp(self):
- self._input = StringIO(
- '<?xml version="1.0" encoding="utf-8"?>'
- '<root>'
- '<field_a>121.0</field_a>'
- '<field_b>dasd</field_b>'
- '<field_c></field_c>'
- '<field_d>2011-12-25 12:45:00</field_d>'
- '</root>'
- )
- self._data = {
- 'field_a': 121,
- 'field_b': 'dasd',
- 'field_c': None,
- 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
- }
- self._complex_data_input = StringIO(
- '<?xml version="1.0" encoding="utf-8"?>'
- '<root>'
- '<creation_date>2011-12-25 12:45:00</creation_date>'
- '<sub_data_list>'
- '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>'
- '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>'
- '</sub_data_list>'
- '<name>name</name>'
- '</root>'
- )
- self._complex_data = {
- "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
- "name": "name",
- "sub_data_list": [
- {
- "sub_id": 1,
- "sub_name": "first"
- },
- {
- "sub_id": 2,
- "sub_name": "second"
- }
- ]
- }
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_parse(self):
- parser = XMLParser()
- data = parser.parse(self._input)
- self.assertEqual(data, self._data)
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_complex_data_parse(self):
- parser = XMLParser()
- data = parser.parse(self._complex_data_input)
- self.assertEqual(data, self._complex_data)
-
-
class TestFileUploadParser(TestCase):
def setUp(self):
class MockRequest(object):
diff --git a/tests/test_relations.py b/tests/test_relations.py
index d478d855..fbe176e2 100644
--- a/tests/test_relations.py
+++ b/tests/test_relations.py
@@ -35,7 +35,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(4)
msg = excinfo.value.detail[0]
- assert msg == "Invalid pk '4' - object does not exist."
+ assert msg == 'Invalid pk "4" - object does not exist.'
def test_pk_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
index 54eea8ce..f68405f0 100644
--- a/tests/test_renderers.py
+++ b/tests/test_renderers.py
@@ -1,26 +1,19 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
-
-from decimal import Decimal
from django.conf.urls import patterns, url, include
from django.core.cache import cache
from django.db import models
from django.test import TestCase
-from django.utils import six, unittest
-from django.utils.six import BytesIO
-from django.utils.six.moves import StringIO
+from django.utils import six
from django.utils.translation import ugettext_lazy as _
from rest_framework import status, permissions
-from rest_framework.compat import yaml, etree
+from rest_framework.compat import OrderedDict
from rest_framework.response import Response
from rest_framework.views import APIView
-from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
- XMLRenderer, JSONPRenderer, BrowsableAPIRenderer
-from rest_framework.parsers import YAMLParser, XMLParser
+from rest_framework.renderers import BaseRenderer, JSONRenderer, BrowsableAPIRenderer
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from collections import MutableMapping
-import datetime
import json
import re
@@ -107,8 +100,6 @@ urlpatterns = patterns(
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^cache$', MockGETView.as_view()),
- url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
- url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
@@ -408,207 +399,6 @@ class AsciiJSONRendererTests(TestCase):
self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8'))
-class JSONPRendererTests(TestCase):
- """
- Tests specific to the JSONP Renderer
- """
-
- urls = 'tests.test_renderers'
-
- def test_without_callback_with_json_renderer(self):
- """
- Test JSONP rendering with View JSON Renderer.
- """
- resp = self.client.get(
- '/jsonp/jsonrenderer',
- HTTP_ACCEPT='application/javascript'
- )
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(
- resp.content,
- ('callback(%s);' % _flat_repr).encode('ascii')
- )
-
- def test_without_callback_without_json_renderer(self):
- """
- Test JSONP rendering without View JSON Renderer.
- """
- resp = self.client.get(
- '/jsonp/nojsonrenderer',
- HTTP_ACCEPT='application/javascript'
- )
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(
- resp.content,
- ('callback(%s);' % _flat_repr).encode('ascii')
- )
-
- def test_with_callback(self):
- """
- Test JSONP rendering with callback function name.
- """
- callback_func = 'myjsonpcallback'
- resp = self.client.get(
- '/jsonp/nojsonrenderer?callback=' + callback_func,
- HTTP_ACCEPT='application/javascript'
- )
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(
- resp.content,
- ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')
- )
-
-
-if yaml:
- _yaml_repr = 'foo: [bar, baz]\n'
-
- class YAMLRendererTests(TestCase):
- """
- Tests specific to the YAML Renderer
- """
-
- def test_render(self):
- """
- Test basic YAML rendering.
- """
- obj = {'foo': ['bar', 'baz']}
- renderer = YAMLRenderer()
- content = renderer.render(obj, 'application/yaml')
- self.assertEqual(content.decode('utf-8'), _yaml_repr)
-
- def test_render_and_parse(self):
- """
- Test rendering and then parsing returns the original object.
- IE obj -> render -> parse -> obj.
- """
- obj = {'foo': ['bar', 'baz']}
-
- renderer = YAMLRenderer()
- parser = YAMLParser()
-
- content = renderer.render(obj, 'application/yaml')
- data = parser.parse(BytesIO(content))
- self.assertEqual(obj, data)
-
- def test_render_decimal(self):
- """
- Test YAML decimal rendering.
- """
- renderer = YAMLRenderer()
- content = renderer.render({'field': Decimal('111.2')}, 'application/yaml')
- self.assertYAMLContains(content.decode('utf-8'), "field: '111.2'")
-
- def assertYAMLContains(self, content, string):
- self.assertTrue(string in content, '%r not in %r' % (string, content))
-
- def test_proper_encoding(self):
- obj = {'countries': ['United Kingdom', 'France', 'España']}
- renderer = YAMLRenderer()
- content = renderer.render(obj, 'application/yaml')
- self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8'))
-
-
-class XMLRendererTestCase(TestCase):
- """
- Tests specific to the XML Renderer
- """
-
- _complex_data = {
- "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
- "name": "name",
- "sub_data_list": [
- {
- "sub_id": 1,
- "sub_name": "first"
- },
- {
- "sub_id": 2,
- "sub_name": "second"
- }
- ]
- }
-
- def test_render_string(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 'astring'}, 'application/xml')
- self.assertXMLContains(content, '<field>astring</field>')
-
- def test_render_integer(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 111}, 'application/xml')
- self.assertXMLContains(content, '<field>111</field>')
-
- def test_render_datetime(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({
- 'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
- }, 'application/xml')
- self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
-
- def test_render_float(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 123.4}, 'application/xml')
- self.assertXMLContains(content, '<field>123.4</field>')
-
- def test_render_decimal(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
- self.assertXMLContains(content, '<field>111.2</field>')
-
- def test_render_none(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': None}, 'application/xml')
- self.assertXMLContains(content, '<field></field>')
-
- def test_render_complex_data(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render(self._complex_data, 'application/xml')
- self.assertXMLContains(content, '<sub_name>first</sub_name>')
- self.assertXMLContains(content, '<sub_name>second</sub_name>')
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_render_and_parse_complex_data(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = StringIO(renderer.render(self._complex_data, 'application/xml'))
-
- parser = XMLParser()
- complex_data_out = parser.parse(content)
- error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
- self.assertEqual(self._complex_data, complex_data_out, error_msg)
-
- def assertXMLContains(self, xml, string):
- self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
- self.assertTrue(xml.endswith('</root>'))
- self.assertTrue(string in xml, '%r not in %r' % (string, xml))
-
-
# Tests for caching issue, #346
class CacheRenderTest(TestCase):
"""
@@ -638,3 +428,25 @@ class CacheRenderTest(TestCase):
assert isinstance(cached_response, Response)
assert cached_response.content == response.content
assert cached_response.status_code == response.status_code
+
+
+class TestJSONIndentationStyles:
+ def test_indented(self):
+ renderer = JSONRenderer()
+ data = OrderedDict([('a', 1), ('b', 2)])
+ assert renderer.render(data) == b'{"a":1,"b":2}'
+
+ def test_compact(self):
+ renderer = JSONRenderer()
+ data = OrderedDict([('a', 1), ('b', 2)])
+ context = {'indent': 4}
+ assert (
+ renderer.render(data, renderer_context=context) ==
+ b'{\n "a": 1,\n "b": 2\n}'
+ )
+
+ def test_long_form(self):
+ renderer = JSONRenderer()
+ renderer.compact = False
+ data = OrderedDict([('a', 1), ('b', 2)])
+ assert renderer.render(data) == b'{"a": 1, "b": 2}'
diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py
index fb881a75..bc955b2e 100644
--- a/tests/test_serializer_bulk_update.py
+++ b/tests/test_serializer_bulk_update.py
@@ -101,7 +101,7 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
- expected_errors = {'non_field_errors': ['Expected a list of items but got type `int`.']}
+ expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
self.assertEqual(serializer.errors, expected_errors)
@@ -118,6 +118,6 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
- expected_errors = {'non_field_errors': ['Expected a list of items but got type `dict`.']}
+ expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
self.assertEqual(serializer.errors, expected_errors)
diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py
index b04a937e..0cee91f1 100644
--- a/tests/test_templatetags.py
+++ b/tests/test_templatetags.py
@@ -54,7 +54,7 @@ class Issue1386Tests(TestCase):
class URLizerTests(TestCase):
"""
- Test if both JSON and YAML URLs are transformed into links well
+ Test if JSON URLs are transformed into links well
"""
def _urlize_dict_check(self, data):
"""
@@ -73,14 +73,3 @@ class URLizerTests(TestCase):
data['"foo_set": [\n "http://api/foos/1/"\n], '] = \
'&quot;foo_set&quot;: [\n &quot;<a href="http://api/foos/1/">http://api/foos/1/</a>&quot;\n], '
self._urlize_dict_check(data)
-
- def test_yaml_with_url(self):
- """
- Test if YAML URLs are transformed into links well
- """
- data = {}
- data['''{users: 'http://api/users/'}'''] = \
- '''{users: &#39;<a href="http://api/users/">http://api/users/</a>&#39;}'''
- data['''foo_set: ['http://api/foos/1/']'''] = \
- '''foo_set: [&#39;<a href="http://api/foos/1/">http://api/foos/1/</a>&#39;]'''
- self._urlize_dict_check(data)
diff --git a/tests/test_versioning.py b/tests/test_versioning.py
new file mode 100644
index 00000000..c44f727d
--- /dev/null
+++ b/tests/test_versioning.py
@@ -0,0 +1,223 @@
+from django.conf.urls import include, url
+from rest_framework import status, versioning
+from rest_framework.decorators import APIView
+from rest_framework.response import Response
+from rest_framework.reverse import reverse
+from rest_framework.test import APIRequestFactory, APITestCase
+
+
+class RequestVersionView(APIView):
+ def get(self, request, *args, **kwargs):
+ return Response({'version': request.version})
+
+
+class ReverseView(APIView):
+ def get(self, request, *args, **kwargs):
+ return Response({'url': reverse('another', request=request)})
+
+
+class RequestInvalidVersionView(APIView):
+ def determine_version(self, request, *args, **kwargs):
+ scheme = self.versioning_class()
+ scheme.allowed_versions = ('v1', 'v2')
+ return (scheme.determine_version(request, *args, **kwargs), scheme)
+
+ def get(self, request, *args, **kwargs):
+ return Response({'version': request.version})
+
+
+factory = APIRequestFactory()
+
+mock_view = lambda request: None
+
+included_patterns = [
+ url(r'^namespaced/$', mock_view, name='another'),
+]
+
+urlpatterns = [
+ url(r'^v1/', include(included_patterns, namespace='v1')),
+ url(r'^another/$', mock_view, name='another'),
+ url(r'^(?P<version>[^/]+)/another/$', mock_view, name='another')
+]
+
+
+class TestRequestVersion:
+ def test_unversioned(self):
+ view = RequestVersionView.as_view()
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_query_param_versioning(self):
+ scheme = versioning.QueryParameterVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/?version=1.2.3')
+ response = view(request)
+ assert response.data == {'version': '1.2.3'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_host_name_versioning(self):
+ scheme = versioning.HostNameVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_HOST='v1.example.org')
+ response = view(request)
+ assert response.data == {'version': 'v1'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_accept_header_versioning(self):
+ scheme = versioning.AcceptHeaderVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=1.2.3')
+ response = view(request)
+ assert response.data == {'version': '1.2.3'}
+
+ request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_url_path_versioning(self):
+ scheme = versioning.URLPathVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/1.2.3/endpoint/')
+ response = view(request, version='1.2.3')
+ assert response.data == {'version': '1.2.3'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_namespace_versioning(self):
+ class FakeResolverMatch:
+ namespace = 'v1'
+
+ scheme = versioning.NamespaceVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v1/endpoint/')
+ request.resolver_match = FakeResolverMatch
+ response = view(request, version='v1')
+ assert response.data == {'version': 'v1'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+
+class TestURLReversing(APITestCase):
+ urls = 'tests.test_versioning'
+
+ def test_reverse_unversioned(self):
+ view = ReverseView.as_view()
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+ def test_reverse_query_param_versioning(self):
+ scheme = versioning.QueryParameterVersioning
+ view = ReverseView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/?version=v1')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/?version=v1'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+ def test_reverse_host_name_versioning(self):
+ scheme = versioning.HostNameVersioning
+ view = ReverseView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_HOST='v1.example.org')
+ response = view(request)
+ assert response.data == {'url': 'http://v1.example.org/another/'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+ def test_reverse_url_path_versioning(self):
+ scheme = versioning.URLPathVersioning
+ view = ReverseView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v1/endpoint/')
+ response = view(request, version='v1')
+ assert response.data == {'url': 'http://testserver/v1/another/'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+ def test_reverse_namespace_versioning(self):
+ class FakeResolverMatch:
+ namespace = 'v1'
+
+ scheme = versioning.NamespaceVersioning
+ view = ReverseView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v1/endpoint/')
+ request.resolver_match = FakeResolverMatch
+ response = view(request, version='v1')
+ assert response.data == {'url': 'http://testserver/v1/namespaced/'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+
+class TestInvalidVersion:
+ def test_invalid_query_param_versioning(self):
+ scheme = versioning.QueryParameterVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/?version=v3')
+ response = view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ def test_invalid_host_name_versioning(self):
+ scheme = versioning.HostNameVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_HOST='v3.example.org')
+ response = view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ def test_invalid_accept_header_versioning(self):
+ scheme = versioning.AcceptHeaderVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3')
+ response = view(request)
+ assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE
+
+ def test_invalid_url_path_versioning(self):
+ scheme = versioning.URLPathVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v3/endpoint/')
+ response = view(request, version='v3')
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ def test_invalid_namespace_versioning(self):
+ class FakeResolverMatch:
+ namespace = 'v3'
+
+ scheme = versioning.NamespaceVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v3/endpoint/')
+ request.resolver_match = FakeResolverMatch
+ response = view(request, version='v3')
+ assert response.status_code == status.HTTP_404_NOT_FOUND