diff options
| author | Tom Christie | 2015-02-13 13:38:44 +0000 | 
|---|---|---|
| committer | Tom Christie | 2015-02-13 13:38:44 +0000 | 
| commit | 4248a8d3fc725d9ae3fe7aaaad7ee12479ab07ab (patch) | |
| tree | c38485aec717a35de8691c3d55bd50ba3e4aae6d /tests | |
| parent | 84260b5dd66cc31858898ff11d5300a73083cca1 (diff) | |
| parent | ad32e14360a23ee3e93ff54ca206c64009d184c9 (diff) | |
| download | django-rest-framework-4248a8d3fc725d9ae3fe7aaaad7ee12479ab07ab.tar.bz2 | |
Merge pull request #2198 from tomchristie/version-3.1
Version 3.1
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/browsable_api/auth_urls.py | 1 | ||||
| -rw-r--r-- | tests/conftest.py | 20 | ||||
| -rw-r--r-- | tests/test_authentication.py | 430 | ||||
| -rw-r--r-- | tests/test_fields.py | 52 | ||||
| -rw-r--r-- | tests/test_generics.py | 8 | ||||
| -rw-r--r-- | tests/test_metadata.py | 60 | ||||
| -rw-r--r-- | tests/test_model_serializer.py | 8 | ||||
| -rw-r--r-- | tests/test_pagination.py | 1048 | ||||
| -rw-r--r-- | tests/test_parsers.py | 60 | ||||
| -rw-r--r-- | tests/test_relations.py | 2 | ||||
| -rw-r--r-- | tests/test_relations_hyperlink.py | 7 | ||||
| -rw-r--r-- | tests/test_renderers.py | 238 | ||||
| -rw-r--r-- | tests/test_serializer_bulk_update.py | 4 | ||||
| -rw-r--r-- | tests/test_templatetags.py | 13 | ||||
| -rw-r--r-- | tests/test_versioning.py | 264 | ||||
| -rw-r--r-- | tests/utils.py | 24 | 
16 files changed, 990 insertions, 1249 deletions
| diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py index bce7dcf9..97bc1036 100644 --- a/tests/browsable_api/auth_urls.py +++ b/tests/browsable_api/auth_urls.py @@ -3,6 +3,7 @@ from django.conf.urls import patterns, url, include  from .views import MockView +  urlpatterns = patterns(      '',      (r'^$', MockView.as_view()), diff --git a/tests/conftest.py b/tests/conftest.py index 31142eaf..44ed070b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,26 +44,6 @@ def pytest_configure():          ),      ) -    try: -        import oauth_provider  # NOQA -        import oauth2  # NOQA -    except ImportError: -        pass -    else: -        settings.INSTALLED_APPS += ( -            'oauth_provider', -        ) - -    try: -        import provider  # NOQA -    except ImportError: -        pass -    else: -        settings.INSTALLED_APPS += ( -            'provider', -            'provider.oauth2', -        ) -      # guardian is optional      try:          import guardian  # NOQA diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 19fe6043..91e49f9d 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -3,8 +3,7 @@ from django.conf.urls import patterns, url, include  from django.contrib.auth.models import User  from django.http import HttpResponse  from django.test import TestCase -from django.utils import six, unittest -from django.utils.http import urlencode +from django.utils import six  from rest_framework import HTTP_HEADER_ENCODING  from rest_framework import exceptions  from rest_framework import permissions @@ -16,17 +15,11 @@ from rest_framework.authentication import (      TokenAuthentication,      BasicAuthentication,      SessionAuthentication, -    OAuthAuthentication, -    OAuth2Authentication  )  from rest_framework.authtoken.models import Token -from rest_framework.compat import oauth2_provider, oauth2_provider_scope -from rest_framework.compat import oauth, oauth_provider  from rest_framework.test import APIRequestFactory, APIClient  from rest_framework.views import APIView  import base64 -import time -import datetime  factory = APIRequestFactory() @@ -50,37 +43,10 @@ urlpatterns = patterns(      (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),      (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),      (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), -    (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), -    ( -        r'^oauth-with-scope/$', -        MockView.as_view( -            authentication_classes=[OAuthAuthentication], -            permission_classes=[permissions.TokenHasReadWriteScope] -        ) -    ),      url(r'^auth/', include('rest_framework.urls', namespace='rest_framework'))  ) -class OAuth2AuthenticationDebug(OAuth2Authentication): -    allow_query_params_token = True - -if oauth2_provider is not None: -    urlpatterns += patterns( -        '', -        url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), -        url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), -        url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), -        url( -            r'^oauth2-with-scope-test/$', -            MockView.as_view( -                authentication_classes=[OAuth2Authentication], -                permission_classes=[permissions.TokenHasReadWriteScope] -            ) -        ) -    ) - -  class BasicAuthTests(TestCase):      """Basic authentication"""      urls = 'tests.test_authentication' @@ -285,400 +251,6 @@ class IncorrectCredentialsTests(TestCase):          self.assertEqual(response.data, {'detail': 'Bad credentials'}) -class OAuthTests(TestCase): -    """OAuth 1.0a authentication""" -    urls = 'tests.test_authentication' - -    def setUp(self): -        # these imports are here because oauth is optional and hiding them in try..except block or compat -        # could obscure problems if something breaks -        from oauth_provider.models import Consumer, Scope -        from oauth_provider.models import Token as OAuthToken -        from oauth_provider import consts - -        self.consts = consts - -        self.csrf_client = APIClient(enforce_csrf_checks=True) -        self.username = 'john' -        self.email = 'lennon@thebeatles.com' -        self.password = 'password' -        self.user = User.objects.create_user(self.username, self.email, self.password) - -        self.CONSUMER_KEY = 'consumer_key' -        self.CONSUMER_SECRET = 'consumer_secret' -        self.TOKEN_KEY = "token_key" -        self.TOKEN_SECRET = "token_secret" - -        self.consumer = Consumer.objects.create( -            key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, -            name='example', user=self.user, status=self.consts.ACCEPTED -        ) - -        self.scope = Scope.objects.create(name="resource name", url="api/") -        self.token = OAuthToken.objects.create( -            user=self.user, consumer=self.consumer, scope=self.scope, -            token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, -            is_approved=True -        ) - -    def _create_authorization_header(self): -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': self.token.key, -            'oauth_consumer_key': self.consumer.key -        } - -        req = oauth.Request(method="GET", url="http://example.com", parameters=params) - -        signature_method = oauth.SignatureMethod_PLAINTEXT() -        req.sign_request(signature_method, self.consumer, self.token) - -        return req.to_header()["Authorization"] - -    def _create_authorization_url_parameters(self): -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': self.token.key, -            'oauth_consumer_key': self.consumer.key -        } - -        req = oauth.Request(method="GET", url="http://example.com", parameters=params) - -        signature_method = oauth.SignatureMethod_PLAINTEXT() -        req.sign_request(signature_method, self.consumer, self.token) -        return dict(req) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_passing_oauth(self): -        """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_repeated_nonce_failing_oauth(self): -        """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails""" -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -        # simulate reply attack auth header containes already used (nonce, timestamp) pair -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_token_removed_failing_oauth(self): -        """Ensure POSTing when there is no OAuth access token in db fails""" -        self.token.delete() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_consumer_status_not_accepted_failing_oauth(self): -        """Ensure POSTing when consumer status is anything other than ACCEPTED fails""" -        for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED): -            self.consumer.status = consumer_status -            self.consumer.save() - -            auth = self._create_authorization_header() -            response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -            self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_with_request_token_failing_oauth(self): -        """Ensure POSTing with unauthorized request token instead of access token fails""" -        self.token.token_type = self.token.REQUEST -        self.token.save() - -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_with_urlencoded_parameters(self): -        """Ensure POSTing with x-www-form-urlencoded auth parameters passes""" -        params = self._create_authorization_url_parameters() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_get_form_with_url_parameters(self): -        """Ensure GETing with auth in url parameters passes""" -        params = self._create_authorization_url_parameters() -        response = self.csrf_client.get('/oauth/', params) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_hmac_sha1_signature_passes(self): -        """Ensure POSTing using HMAC_SHA1 signature method passes""" -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': self.token.key, -            'oauth_consumer_key': self.consumer.key -        } - -        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - -        signature_method = oauth.SignatureMethod_HMAC_SHA1() -        req.sign_request(signature_method, self.consumer, self.token) -        auth = req.to_header()["Authorization"] - -        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_get_form_with_readonly_resource_passing_auth(self): -        """Ensure POSTing with a readonly scope instead of a write scope fails""" -        read_only_access_token = self.token -        read_only_access_token.scope.is_readonly = True -        read_only_access_token.scope.save() -        params = self._create_authorization_url_parameters() -        response = self.csrf_client.get('/oauth-with-scope/', params) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_with_readonly_resource_failing_auth(self): -        """Ensure POSTing with a readonly resource instead of a write scope fails""" -        read_only_access_token = self.token -        read_only_access_token.scope.is_readonly = True -        read_only_access_token.scope.save() -        params = self._create_authorization_url_parameters() -        response = self.csrf_client.post('/oauth-with-scope/', params) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_post_form_with_write_resource_passing_auth(self): -        """Ensure POSTing with a write resource succeed""" -        read_write_access_token = self.token -        read_write_access_token.scope.is_readonly = False -        read_write_access_token.scope.save() -        params = self._create_authorization_url_parameters() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_bad_consumer_key(self): -        """Ensure POSTing using HMAC_SHA1 signature method passes""" -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': self.token.key, -            'oauth_consumer_key': 'badconsumerkey' -        } - -        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - -        signature_method = oauth.SignatureMethod_HMAC_SHA1() -        req.sign_request(signature_method, self.consumer, self.token) -        auth = req.to_header()["Authorization"] - -        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') -    @unittest.skipUnless(oauth, 'oauth2 not installed') -    def test_bad_token_key(self): -        """Ensure POSTing using HMAC_SHA1 signature method passes""" -        params = { -            'oauth_version': "1.0", -            'oauth_nonce': oauth.generate_nonce(), -            'oauth_timestamp': int(time.time()), -            'oauth_token': 'badtokenkey', -            'oauth_consumer_key': self.consumer.key -        } - -        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) - -        signature_method = oauth.SignatureMethod_HMAC_SHA1() -        req.sign_request(signature_method, self.consumer, self.token) -        auth = req.to_header()["Authorization"] - -        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - - -class OAuth2Tests(TestCase): -    """OAuth 2.0 authentication""" -    urls = 'tests.test_authentication' - -    def setUp(self): -        self.csrf_client = APIClient(enforce_csrf_checks=True) -        self.username = 'john' -        self.email = 'lennon@thebeatles.com' -        self.password = 'password' -        self.user = User.objects.create_user(self.username, self.email, self.password) - -        self.CLIENT_ID = 'client_key' -        self.CLIENT_SECRET = 'client_secret' -        self.ACCESS_TOKEN = "access_token" -        self.REFRESH_TOKEN = "refresh_token" - -        self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( -            client_id=self.CLIENT_ID, -            client_secret=self.CLIENT_SECRET, -            redirect_uri='', -            client_type=0, -            name='example', -            user=None, -        ) - -        self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( -            token=self.ACCESS_TOKEN, -            client=self.oauth2_client, -            user=self.user, -        ) -        self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( -            user=self.user, -            access_token=self.access_token, -            client=self.oauth2_client -        ) - -    def _create_authorization_header(self, token=None): -        return "Bearer {0}".format(token or self.access_token.token) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_with_wrong_authorization_header_token_type_failing(self): -        """Ensure that a wrong token type lead to the correct HTTP error status code""" -        auth = "Wrong token-type-obsviously" -        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_with_wrong_authorization_header_token_format_failing(self): -        """Ensure that a wrong token format lead to the correct HTTP error status code""" -        auth = "Bearer wrong token format" -        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_with_wrong_authorization_header_token_failing(self): -        """Ensure that a wrong token lead to the correct HTTP error status code""" -        auth = "Bearer wrong-token" -        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_with_wrong_authorization_header_token_missing(self): -        """Ensure that a missing token lead to the correct HTTP error status code""" -        auth = "Bearer" -        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_passing_auth(self): -        """Ensure GETing form over OAuth with correct client credentials succeed""" -        auth = self._create_authorization_header() -        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_passing_auth_url_transport(self): -        """Ensure GETing form over OAuth with correct client credentials in form data succeed""" -        response = self.csrf_client.post( -            '/oauth2-test/', -            data={'access_token': self.access_token.token} -        ) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_passing_auth_url_transport(self): -        """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" -        query = urlencode({'access_token': self.access_token.token}) -        response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_failing_auth_url_transport(self): -        """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" -        query = urlencode({'access_token': self.access_token.token}) -        response = self.csrf_client.get('/oauth2-test/?%s' % query) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_passing_auth(self): -        """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_token_removed_failing_auth(self): -        """Ensure POSTing when there is no OAuth access token in db fails""" -        self.access_token.delete() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_with_refresh_token_failing_auth(self): -        """Ensure POSTing with refresh token instead of access token fails""" -        auth = self._create_authorization_header(token=self.refresh_token.token) -        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_with_expired_access_token_failing_auth(self): -        """Ensure POSTing with expired access token fails with an 'Invalid token' error""" -        self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10)  # 10 seconds late -        self.access_token.save() -        auth = self._create_authorization_header() -        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) -        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) -        self.assertIn('Invalid token', response.content) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_with_invalid_scope_failing_auth(self): -        """Ensure POSTing with a readonly scope instead of a write scope fails""" -        read_only_access_token = self.access_token -        read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] -        read_only_access_token.save() -        auth = self._create_authorization_header(token=read_only_access_token.token) -        response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) -        response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_post_form_with_valid_scope_passing_auth(self): -        """Ensure POSTing with a write scope succeed""" -        read_write_access_token = self.access_token -        read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] -        read_write_access_token.save() -        auth = self._create_authorization_header(token=read_write_access_token.token) -        response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 200) - -  class FailingAuthAccessedInRenderer(TestCase):      def setUp(self):          class AuthAccessingRenderer(renderers.BaseRenderer): diff --git a/tests/test_fields.py b/tests/test_fields.py index 6744cf64..ab3418bd 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -347,7 +347,7 @@ class TestBooleanField(FieldValues):          False: False,      }      invalid_inputs = { -        'foo': ['`foo` is not a valid boolean.'], +        'foo': ['"foo" is not a valid boolean.'],          None: ['This field may not be null.']      }      outputs = { @@ -377,7 +377,7 @@ class TestNullBooleanField(FieldValues):          None: None      }      invalid_inputs = { -        'foo': ['`foo` is not a valid boolean.'], +        'foo': ['"foo" is not a valid boolean.'],      }      outputs = {          'true': True, @@ -410,6 +410,14 @@ class TestCharField(FieldValues):      }      field = serializers.CharField() +    def test_trim_whitespace_default(self): +        field = serializers.CharField() +        assert field.to_internal_value(' abc ') == 'abc' + +    def test_trim_whitespace_disabled(self): +        field = serializers.CharField(trim_whitespace=False) +        assert field.to_internal_value(' abc ') == ' abc ' +  class TestEmailField(FieldValues):      """ @@ -448,7 +456,7 @@ class TestSlugField(FieldValues):          'slug-99': 'slug-99',      }      invalid_inputs = { -        'slug 99': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."] +        'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.']      }      outputs = {}      field = serializers.SlugField() @@ -666,8 +674,8 @@ class TestDateField(FieldValues):          datetime.date(2001, 1, 1): datetime.date(2001, 1, 1),      }      invalid_inputs = { -        'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], -        '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], +        'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], +        '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'],          datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'],      }      outputs = { @@ -684,7 +692,7 @@ class TestCustomInputFormatDateField(FieldValues):          '1 Jan 2001': datetime.date(2001, 1, 1),      }      invalid_inputs = { -        '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY'] +        '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.']      }      outputs = {}      field = serializers.DateField(input_formats=['%d %b %Y']) @@ -728,8 +736,8 @@ class TestDateTimeField(FieldValues):          '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC())      }      invalid_inputs = { -        'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], -        '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], +        'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], +        '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],          datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'],      }      outputs = { @@ -747,7 +755,7 @@ class TestCustomInputFormatDateTimeField(FieldValues):          '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()),      }      invalid_inputs = { -        '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY'] +        '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.']      }      outputs = {}      field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y']) @@ -799,8 +807,8 @@ class TestTimeField(FieldValues):          datetime.time(13, 00): datetime.time(13, 00),      }      invalid_inputs = { -        'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], -        '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], +        'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], +        '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],      }      outputs = {          datetime.time(13, 00): '13:00:00' @@ -816,7 +824,7 @@ class TestCustomInputFormatTimeField(FieldValues):          '1:00pm': datetime.time(13, 00),      }      invalid_inputs = { -        '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]'], +        '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'],      }      outputs = {}      field = serializers.TimeField(input_formats=['%I:%M%p']) @@ -858,7 +866,7 @@ class TestChoiceField(FieldValues):          'good': 'good',      }      invalid_inputs = { -        'amazing': ['`amazing` is not a valid choice.'] +        'amazing': ['"amazing" is not a valid choice.']      }      outputs = {          'good': 'good', @@ -898,8 +906,8 @@ class TestChoiceFieldWithType(FieldValues):          3: 3,      }      invalid_inputs = { -        5: ['`5` is not a valid choice.'], -        'abc': ['`abc` is not a valid choice.'] +        5: ['"5" is not a valid choice.'], +        'abc': ['"abc" is not a valid choice.']      }      outputs = {          '1': 1, @@ -925,7 +933,7 @@ class TestChoiceFieldWithListChoices(FieldValues):          'good': 'good',      }      invalid_inputs = { -        'awful': ['`awful` is not a valid choice.'] +        'awful': ['"awful" is not a valid choice.']      }      outputs = {          'good': 'good' @@ -943,8 +951,8 @@ class TestMultipleChoiceField(FieldValues):          ('aircon', 'manual'): set(['aircon', 'manual']),      }      invalid_inputs = { -        'abc': ['Expected a list of items but got type `str`.'], -        ('aircon', 'incorrect'): ['`incorrect` is not a valid choice.'] +        'abc': ['Expected a list of items but got type "str".'], +        ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.']      }      outputs = [          (['aircon', 'manual'], set(['aircon', 'manual'])) @@ -1054,7 +1062,7 @@ class TestListField(FieldValues):          (['1', '2', '3'], [1, 2, 3])      ]      invalid_inputs = [ -        ('not a list', ['Expected a list of items but got type `str`']), +        ('not a list', ['Expected a list of items but got type "str".']),          ([1, 2, 'error'], ['A valid integer is required.'])      ]      outputs = [ @@ -1072,7 +1080,7 @@ class TestUnvalidatedListField(FieldValues):          ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),      ]      invalid_inputs = [ -        ('not a list', ['Expected a list of items but got type `str`']), +        ('not a list', ['Expected a list of items but got type "str".']),      ]      outputs = [          ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), @@ -1089,7 +1097,7 @@ class TestDictField(FieldValues):      ]      invalid_inputs = [          ({'a': 1, 'b': None}, ['This field may not be null.']), -        ('not a dict', ['Expected a dictionary of items but got type `str`']), +        ('not a dict', ['Expected a dictionary of items but got type "str".']),      ]      outputs = [          ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), @@ -1105,7 +1113,7 @@ class TestUnvalidatedDictField(FieldValues):          ({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}),      ]      invalid_inputs = [ -        ('not a dict', ['Expected a dictionary of items but got type `str`']), +        ('not a dict', ['Expected a dictionary of items but got type "str".']),      ]      outputs = [          ({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}), diff --git a/tests/test_generics.py b/tests/test_generics.py index 94023c30..88e792ce 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -117,7 +117,7 @@ class TestRootView(TestCase):          with self.assertNumQueries(0):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) +        self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'})      def test_delete_root_view(self):          """ @@ -127,7 +127,7 @@ class TestRootView(TestCase):          with self.assertNumQueries(0):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) +        self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'})      def test_post_cannot_set_id(self):          """ @@ -181,7 +181,7 @@ class TestInstanceView(TestCase):          with self.assertNumQueries(0):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) +        self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'})      def test_put_instance_view(self):          """ @@ -483,7 +483,7 @@ class TestFilterBackendAppliedToViews(TestCase):          request = factory.get('/1')          response = instance_view(request, pk=1).render()          self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) -        self.assertEqual(response.data, {'detail': 'Not found'}) +        self.assertEqual(response.data, {'detail': 'Not found.'})      def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):          """ diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 5ff59c72..5031c0f3 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,9 +1,8 @@  from __future__ import unicode_literals - -from rest_framework import exceptions, serializers, views +from rest_framework import exceptions, serializers, status, views, versioning  from rest_framework.request import Request +from rest_framework.renderers import BrowsableAPIRenderer  from rest_framework.test import APIRequestFactory -import pytest  request = Request(APIRequestFactory().options('/')) @@ -17,7 +16,8 @@ class TestMetadata:              """Example view."""              pass -        response = ExampleView().options(request=request) +        view = ExampleView.as_view() +        response = view(request=request)          expected = {              'name': 'Example',              'description': 'Example view.', @@ -31,7 +31,7 @@ class TestMetadata:                  'multipart/form-data'              ]          } -        assert response.status_code == 200 +        assert response.status_code == status.HTTP_200_OK          assert response.data == expected      def test_none_metadata(self): @@ -42,8 +42,10 @@ class TestMetadata:          class ExampleView(views.APIView):              metadata_class = None -        with pytest.raises(exceptions.MethodNotAllowed): -            ExampleView().options(request=request) +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED +        assert response.data == {'detail': 'Method "OPTIONS" not allowed.'}      def test_actions(self):          """ @@ -63,7 +65,8 @@ class TestMetadata:              def get_serializer(self):                  return ExampleSerializer() -        response = ExampleView().options(request=request) +        view = ExampleView.as_view() +        response = view(request=request)          expected = {              'name': 'Example',              'description': 'Example view.', @@ -104,7 +107,7 @@ class TestMetadata:                  }              }          } -        assert response.status_code == 200 +        assert response.status_code == status.HTTP_200_OK          assert response.data == expected      def test_global_permissions(self): @@ -132,8 +135,9 @@ class TestMetadata:                  if request.method == 'POST':                      raise exceptions.PermissionDenied() -        response = ExampleView().options(request=request) -        assert response.status_code == 200 +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_200_OK          assert list(response.data['actions'].keys()) == ['PUT']      def test_object_permissions(self): @@ -161,6 +165,36 @@ class TestMetadata:                  if self.request.method == 'PUT':                      raise exceptions.PermissionDenied() -        response = ExampleView().options(request=request) -        assert response.status_code == 200 +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_200_OK          assert list(response.data['actions'].keys()) == ['POST'] + +    def test_bug_2455_clone_request(self): +        class ExampleView(views.APIView): +            renderer_classes = (BrowsableAPIRenderer,) + +            def post(self, request): +                pass + +            def get_serializer(self): +                assert hasattr(self.request, 'version') +                return serializers.Serializer() + +        view = ExampleView.as_view() +        view(request=request) + +    def test_bug_2477_clone_request(self): +        class ExampleView(views.APIView): +            renderer_classes = (BrowsableAPIRenderer,) + +            def post(self, request): +                pass + +            def get_serializer(self): +                assert hasattr(self.request, 'versioning_scheme') +                return serializers.Serializer() + +        scheme = versioning.QueryParameterVersioning +        view = ExampleView.as_view(versioning_class=scheme) +        view(request=request) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 247b309a..bce2008a 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -216,7 +216,7 @@ class TestRegularFieldMappings(TestCase):          with self.assertRaises(ImproperlyConfigured) as excinfo:              TestSerializer().fields -        expected = 'Field name `invalid` is not valid for model `ModelBase`.' +        expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.'          assert str(excinfo.exception) == expected      def test_missing_field(self): @@ -234,8 +234,8 @@ class TestRegularFieldMappings(TestCase):          with self.assertRaises(AssertionError) as excinfo:              TestSerializer().fields          expected = ( -            'Field `missing` has been declared on serializer ' -            '`TestSerializer`, but is missing from `Meta.fields`.' +            "The field 'missing' was declared on serializer TestSerializer, " +            "but has not been included in the 'fields' option."          )          assert str(excinfo.exception) == expected @@ -637,5 +637,5 @@ class TestSerializerMetaClass(TestCase):          exception = result.exception          self.assertEqual(              str(exception), -            "Cannot set both 'fields' and 'exclude'." +            "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer."          ) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1fd9cf9c..13bfb627 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,553 +1,671 @@ +# coding: utf-8  from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.core.paginator import Paginator -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, serializers, status, pagination, filters -from rest_framework.compat import django_filters +from rest_framework import exceptions, generics, pagination, serializers, status, filters +from rest_framework.request import Request +from rest_framework.pagination import PageLink, PAGE_BREAK  from rest_framework.test import APIRequestFactory -from .models import BasicModel, FilterableItem +import pytest  factory = APIRequestFactory() -# Helper function to split arguments out of an url -def split_arguments_from_url(url): -    if '?' not in url: -        return url +class TestPaginationIntegration: +    """ +    Integration tests. +    """ -    path, args = url.split('?') -    args = dict(r.split('=') for r in args.split('&')) -    return path, args +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item +        class EvenItemsOnly(filters.BaseFilterBackend): +            def filter_queryset(self, request, queryset, view): +                return [item for item in queryset if item % 2 == 0] + +        class BasicPagination(pagination.PageNumberPagination): +            paginate_by = 5 +            paginate_by_param = 'page_size' +            max_paginate_by = 20 + +        self.view = generics.ListAPIView.as_view( +            serializer_class=PassThroughSerializer, +            queryset=range(1, 101), +            filter_backends=[EvenItemsOnly], +            pagination_class=BasicPagination +        ) -class BasicSerializer(serializers.ModelSerializer): -    class Meta: -        model = BasicModel +    def test_filtered_items_are_paginated(self): +        request = factory.get('/', {'page': 2}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [12, 14, 16, 18, 20], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page=3', +            'count': 50 +        } +    def test_setting_page_size(self): +        """ +        When 'paginate_by_param' is set, the client may choose a page size. +        """ +        request = factory.get('/', {'page_size': 10}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=10', +            'count': 50 +        } -class FilterableItemSerializer(serializers.ModelSerializer): -    class Meta: -        model = FilterableItem +    def test_setting_page_size_over_maximum(self): +        """ +        When page_size parameter exceeds maxiumum allowable, +        then it should be capped to the maxiumum. +        """ +        request = factory.get('/', {'page_size': 1000}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [ +                2, 4, 6, 8, 10, 12, 14, 16, 18, 20, +                22, 24, 26, 28, 30, 32, 34, 36, 38, 40 +            ], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=1000', +            'count': 50 +        } +    def test_setting_page_size_to_zero(self): +        """ +        When page_size parameter is invalid it should return to the default. +        """ +        request = factory.get('/', {'page_size': 0}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [2, 4, 6, 8, 10], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=0', +            'count': 50 +        } -class RootView(generics.ListCreateAPIView): -    """ -    Example description for OPTIONS. -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by = 10 +    def test_additional_query_params_are_preserved(self): +        request = factory.get('/', {'page': 2, 'filter': 'even'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [12, 14, 16, 18, 20], +            'previous': 'http://testserver/?filter=even', +            'next': 'http://testserver/?filter=even&page=3', +            'count': 50 +        } +    def test_404_not_found_for_zero_page(self): +        request = factory.get('/', {'page': '0'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND +        assert response.data == { +            'detail': 'Invalid page "0": That page number is less than 1.' +        } -class DefaultPageSizeKwargView(generics.ListAPIView): -    """ -    View for testing default paginate_by_param usage -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer +    def test_404_not_found_for_invalid_page(self): +        request = factory.get('/', {'page': 'invalid'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND +        assert response.data == { +            'detail': 'Invalid page "invalid": That page number is not an integer.' +        } -class PaginateByParamView(generics.ListAPIView): +class TestPaginationDisabledIntegration:      """ -    View for testing custom paginate_by_param usage +    Integration tests for disabled pagination.      """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by_param = 'page_size' +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item -class MaxPaginateByView(generics.ListAPIView): -    """ -    View for testing custom max_paginate_by usage -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by = 3 -    max_paginate_by = 5 -    paginate_by_param = 'page_size' +        self.view = generics.ListAPIView.as_view( +            serializer_class=PassThroughSerializer, +            queryset=range(1, 101), +            pagination_class=None +        ) + +    def test_unpaginated_list(self): +        request = factory.get('/', {'page': 2}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == list(range(1, 101)) -class IntegrationTestPagination(TestCase): +class TestDeprecatedStylePagination:      """ -    Integration tests for paginated list views. +    Integration tests for deprecated style of setting pagination +    attributes on the view.      """ -    def setUp(self): -        """ -        Create 26 BasicModel instances. -        """ -        for char in 'abcdefghijklmnopqrstuvwxyz': -            BasicModel(text=char * 3).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = RootView.as_view() - -    def test_get_paginated_root_view(self): -        """ -        GET requests to paginated ListCreateAPIView should return paginated results. -        """ -        request = factory.get('/') -        # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[10:20]) -        self.assertNotEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[20:]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - - -class IntegrationTestPaginationAndFiltering(TestCase): - -    def setUp(self): -        """ -        Create 50 FilterableItem instances. -        """ -        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) -        for i in range(26): -            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. -            decimal = base_data[1] + i -            date = base_data[2] - datetime.timedelta(days=i * 2) -            FilterableItem(text=text, decimal=decimal, date=date).save() - -        self.objects = FilterableItem.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} -            for obj in self.objects.all() -        ] - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_get_django_filter_paginated_filtered_root_view(self): -        """ -        GET requests to paginated filtered ListCreateAPIView should return -        paginated results. The next and previous links should preserve the -        filtered parameters. -        """ -        class DecimalFilter(django_filters.FilterSet): -            decimal = django_filters.NumberFilter(lookup_type='lt') - -            class Meta: -                model = FilterableItem -                fields = ['text', 'decimal', 'date'] - -        class FilterFieldsRootView(generics.ListCreateAPIView): -            queryset = FilterableItem.objects.all() -            serializer_class = FilterableItemSerializer -            paginate_by = 10 -            filter_class = DecimalFilter -            filter_backends = (filters.DjangoFilterBackend,) - -        view = FilterFieldsRootView.as_view() - -        EXPECTED_NUM_QUERIES = 2 - -        request = factory.get('/', {'decimal': '15.20'}) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[10:15]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['previous'])) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -    def test_get_basic_paginated_filtered_root_view(self): -        """ -        Same as `test_get_django_filter_paginated_filtered_root_view`, -        except using a custom filter backend instead of the django-filter -        backend, -        """ +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item -        class DecimalFilterBackend(filters.BaseFilterBackend): -            def filter_queryset(self, request, queryset, view): -                return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) - -        class BasicFilterFieldsRootView(generics.ListCreateAPIView): -            queryset = FilterableItem.objects.all() -            serializer_class = FilterableItemSerializer -            paginate_by = 10 -            filter_backends = (DecimalFilterBackend,) - -        view = BasicFilterFieldsRootView.as_view() - -        request = factory.get('/', {'decimal': '15.20'}) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[10:15]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['previous'])) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - - -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): -    class Meta: -        object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): -    """ -    Unit tests for pagination of primitive objects. -    """ +        class ExampleView(generics.ListAPIView): +            serializer_class = PassThroughSerializer +            queryset = range(1, 101) +            pagination_class = pagination.PageNumberPagination +            paginate_by = 20 +            page_query_param = 'page_number' -    def setUp(self): -        self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] -        paginator = Paginator(self.objects, 10) -        self.first_page = paginator.page(1) -        self.last_page = paginator.page(3) - -    def test_native_pagination(self): -        serializer = pagination.PaginationSerializer(self.first_page) -        self.assertEqual(serializer.data['count'], 26) -        self.assertEqual(serializer.data['next'], '?page=2') -        self.assertEqual(serializer.data['previous'], None) -        self.assertEqual(serializer.data['results'], self.objects[:10]) - -        serializer = pagination.PaginationSerializer(self.last_page) -        self.assertEqual(serializer.data['count'], 26) -        self.assertEqual(serializer.data['next'], None) -        self.assertEqual(serializer.data['previous'], '?page=2') -        self.assertEqual(serializer.data['results'], self.objects[20:]) - -    def test_context_available_in_result(self): -        """ -        Ensure context gets passed through to the object serializer. -        """ -        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) -        serializer.data -        results = serializer.fields[serializer.results_field] -        self.assertEqual(serializer.context, results.context) +        self.view = ExampleView.as_view() + +    def test_paginate_by_attribute_on_view(self): +        request = factory.get('/?page_number=2') +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [ +                21, 22, 23, 24, 25, 26, 27, 28, 29, 30, +                31, 32, 33, 34, 35, 36, 37, 38, 39, 40 +            ], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page_number=3', +            'count': 100 +        } -class TestUnpaginated(TestCase): +class TestPageNumberPagination:      """ -    Tests for list views without pagination. +    Unit tests for `pagination.PageNumberPagination`.      """ -    def setUp(self): -        """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = DefaultPageSizeKwargView.as_view() - -    def test_unpaginated(self): -        """ -        Tests the default page size for this view. -        no page size --> no limit --> no meta data -        """ -        request = factory.get('/') -        response = self.view(request) -        self.assertEqual(response.data, self.data) +    def setup(self): +        class ExamplePagination(pagination.PageNumberPagination): +            paginate_by = 5 +        self.pagination = ExamplePagination() +        self.queryset = range(1, 101) + +    def paginate_queryset(self, request): +        return list(self.pagination.paginate_queryset(self.queryset, request)) + +    def get_paginated_content(self, queryset): +        response = self.pagination.get_paginated_response(queryset) +        return response.data + +    def get_html_context(self): +        return self.pagination.get_html_context() + +    def test_no_page_number(self): +        request = Request(factory.get('/')) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [1, 2, 3, 4, 5] +        assert content == { +            'results': [1, 2, 3, 4, 5], +            'previous': None, +            'next': 'http://testserver/?page=2', +            'count': 100 +        } +        assert context == { +            'previous_url': None, +            'next_url': 'http://testserver/?page=2', +            'page_links': [ +                PageLink('http://testserver/', 1, True, False), +                PageLink('http://testserver/?page=2', 2, False, False), +                PageLink('http://testserver/?page=3', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=20', 20, False, False), +            ] +        } +        assert self.pagination.display_page_controls +        assert isinstance(self.pagination.to_html(), type('')) + +    def test_second_page(self): +        request = Request(factory.get('/', {'page': 2})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [6, 7, 8, 9, 10] +        assert content == { +            'results': [6, 7, 8, 9, 10], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page=3', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/', +            'next_url': 'http://testserver/?page=3', +            'page_links': [ +                PageLink('http://testserver/', 1, False, False), +                PageLink('http://testserver/?page=2', 2, True, False), +                PageLink('http://testserver/?page=3', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=20', 20, False, False), +            ] +        } + +    def test_last_page(self): +        request = Request(factory.get('/', {'page': 'last'})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [96, 97, 98, 99, 100] +        assert content == { +            'results': [96, 97, 98, 99, 100], +            'previous': 'http://testserver/?page=19', +            'next': None, +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?page=19', +            'next_url': None, +            'page_links': [ +                PageLink('http://testserver/', 1, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=18', 18, False, False), +                PageLink('http://testserver/?page=19', 19, False, False), +                PageLink('http://testserver/?page=20', 20, True, False), +            ] +        } + +    def test_invalid_page(self): +        request = Request(factory.get('/', {'page': 'invalid'})) +        with pytest.raises(exceptions.NotFound): +            self.paginate_queryset(request) -class TestCustomPaginateByParam(TestCase): +class TestLimitOffset:      """ -    Tests for list views with default page size kwarg +    Unit tests for `pagination.LimitOffsetPagination`.      """ -    def setUp(self): +    def setup(self): +        class ExamplePagination(pagination.LimitOffsetPagination): +            default_limit = 10 +        self.pagination = ExamplePagination() +        self.queryset = range(1, 101) + +    def paginate_queryset(self, request): +        return list(self.pagination.paginate_queryset(self.queryset, request)) + +    def get_paginated_content(self, queryset): +        response = self.pagination.get_paginated_response(queryset) +        return response.data + +    def get_html_context(self): +        return self.pagination.get_html_context() + +    def test_no_offset(self): +        request = Request(factory.get('/', {'limit': 5})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [1, 2, 3, 4, 5] +        assert content == { +            'results': [1, 2, 3, 4, 5], +            'previous': None, +            'next': 'http://testserver/?limit=5&offset=5', +            'count': 100 +        } +        assert context == { +            'previous_url': None, +            'next_url': 'http://testserver/?limit=5&offset=5', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, True, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, False, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } +        assert self.pagination.display_page_controls +        assert isinstance(self.pagination.to_html(), type('')) + +    def test_single_offset(self):          """ -        Create 13 BasicModel instances. +        When the offset is not a multiple of the limit we get some edge cases: +        * The first page should still be offset zero. +        * We may end up displaying an extra page in the pagination control.          """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = PaginateByParamView.as_view() - -    def test_default_page_size(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 1})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [2, 3, 4, 5, 6] +        assert content == { +            'results': [2, 3, 4, 5, 6], +            'previous': 'http://testserver/?limit=5', +            'next': 'http://testserver/?limit=5&offset=6', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5', +            'next_url': 'http://testserver/?limit=5&offset=6', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=1', 2, True, False), +                PageLink('http://testserver/?limit=5&offset=6', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=96', 21, False, False), +            ] +        } + +    def test_first_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 5})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [6, 7, 8, 9, 10] +        assert content == { +            'results': [6, 7, 8, 9, 10], +            'previous': 'http://testserver/?limit=5', +            'next': 'http://testserver/?limit=5&offset=10', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5', +            'next_url': 'http://testserver/?limit=5&offset=10', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, True, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } + +    def test_middle_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 10})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [11, 12, 13, 14, 15] +        assert content == { +            'results': [11, 12, 13, 14, 15], +            'previous': 'http://testserver/?limit=5&offset=5', +            'next': 'http://testserver/?limit=5&offset=15', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5&offset=5', +            'next_url': 'http://testserver/?limit=5&offset=15', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, False, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, True, False), +                PageLink('http://testserver/?limit=5&offset=15', 4, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } + +    def test_ending_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 95})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [96, 97, 98, 99, 100] +        assert content == { +            'results': [96, 97, 98, 99, 100], +            'previous': 'http://testserver/?limit=5&offset=90', +            'next': None, +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5&offset=90', +            'next_url': None, +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=85', 18, False, False), +                PageLink('http://testserver/?limit=5&offset=90', 19, False, False), +                PageLink('http://testserver/?limit=5&offset=95', 20, True, False), +            ] +        } + +    def test_invalid_offset(self):          """ -        Tests the default page size for this view. -        no page size --> no limit --> no meta data +        An invalid offset query param should be treated as 0.          """ -        request = factory.get('/') -        response = self.view(request).render() -        self.assertEqual(response.data, self.data) +        request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) +        queryset = self.paginate_queryset(request) +        assert queryset == [1, 2, 3, 4, 5] -    def test_paginate_by_param(self): +    def test_invalid_limit(self):          """ -        If paginate_by_param is set, the new kwarg should limit per view requests. +        An invalid limit query param should be ignored in favor of the default.          """ -        request = factory.get('/', {'page_size': 5}) -        response = self.view(request).render() -        self.assertEqual(response.data['count'], 13) -        self.assertEqual(response.data['results'], self.data[:5]) +        request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) +        queryset = self.paginate_queryset(request) +        assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -class TestMaxPaginateByParam(TestCase): +class TestCursorPagination:      """ -    Tests for list views with max_paginate_by kwarg +    Unit tests for `pagination.CursorPagination`.      """ -    def setUp(self): +    def setup(self): +        class MockObject(object): +            def __init__(self, idx): +                self.created = idx + +        class MockQuerySet(object): +            def __init__(self, items): +                self.items = items + +            def filter(self, created__gt=None, created__lt=None): +                if created__gt is not None: +                    return MockQuerySet([ +                        item for item in self.items +                        if item.created > int(created__gt) +                    ]) + +                assert created__lt is not None +                return MockQuerySet([ +                    item for item in self.items +                    if item.created < int(created__lt) +                ]) + +            def order_by(self, *ordering): +                if ordering[0].startswith('-'): +                    return MockQuerySet(list(reversed(self.items))) +                return self + +            def __getitem__(self, sliced): +                return self.items[sliced] + +        class ExamplePagination(pagination.CursorPagination): +            page_size = 5 +            ordering = 'created' + +        self.pagination = ExamplePagination() +        self.queryset = MockQuerySet([ +            MockObject(idx) for idx in [ +                1, 1, 1, 1, 1, +                1, 2, 3, 4, 4, +                4, 4, 5, 6, 7, +                7, 7, 7, 7, 7, +                7, 7, 7, 8, 9, +                9, 9, 9, 9, 9 +            ] +        ]) + +    def get_pages(self, url):          """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = MaxPaginateByView.as_view() - -    def test_max_paginate_by(self): -        """ -        If max_paginate_by is set, it should limit page size for the view. -        """ -        request = factory.get('/', data={'page_size': 10}) -        response = self.view(request).render() -        self.assertEqual(response.data['count'], 13) -        self.assertEqual(response.data['results'], self.data[:5]) +        Given a URL return a tuple of: -    def test_max_paginate_by_without_page_size_param(self): +        (previous page, current page, next page, previous url, next url)          """ -        If max_paginate_by is set, but client does not specifiy page_size, -        standard `paginate_by` behavior should be used. -        """ -        request = factory.get('/') -        response = self.view(request).render() -        self.assertEqual(response.data['results'], self.data[:3]) - - -# Tests for context in pagination serializers +        request = Request(factory.get(url)) +        queryset = self.pagination.paginate_queryset(self.queryset, request) +        current = [item.created for item in queryset] -class CustomField(serializers.ReadOnlyField): -    def to_native(self, value): -        if 'view' not in self.context: -            raise RuntimeError("context isn't getting passed into custom field") -        return "value" +        next_url = self.pagination.get_next_link() +        previous_url = self.pagination.get_previous_link() +        if next_url is not None: +            request = Request(factory.get(next_url)) +            queryset = self.pagination.paginate_queryset(self.queryset, request) +            next = [item.created for item in queryset] +        else: +            next = None -class BasicModelSerializer(serializers.Serializer): -    text = CustomField() - -    def to_native(self, value): -        if 'view' not in self.context: -            raise RuntimeError("context isn't getting passed into serializer") -        return super(BasicSerializer, self).to_native(value) +        if previous_url is not None: +            request = Request(factory.get(previous_url)) +            queryset = self.pagination.paginate_queryset(self.queryset, request) +            previous = [item.created for item in queryset] +        else: +            previous = None +        return (previous, current, next, previous_url, next_url) -class TestContextPassedToCustomField(TestCase): -    def setUp(self): -        BasicModel.objects.create(text='ala ma kota') +    def test_invalid_cursor(self): +        request = Request(factory.get('/', {'cursor': '123'})) +        with pytest.raises(exceptions.NotFound): +            self.pagination.paginate_queryset(self.queryset, request) -    def test_with_pagination(self): -        class ListView(generics.ListCreateAPIView): -            queryset = BasicModel.objects.all() -            serializer_class = BasicModelSerializer -            paginate_by = 1 +    def test_use_with_ordering_filter(self): +        class MockView: +            filter_backends = (filters.OrderingFilter,) +            ordering_fields = ['username', 'created'] +            ordering = 'created' -        self.view = ListView.as_view() -        request = factory.get('/') -        response = self.view(request).render() +        request = Request(factory.get('/', {'ordering': 'username'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('username',) -        self.assertEqual(response.status_code, status.HTTP_200_OK) +        request = Request(factory.get('/', {'ordering': '-username'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('-username',) +        request = Request(factory.get('/', {'ordering': 'invalid'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('created',) -# Tests for custom pagination serializers +    def test_cursor_pagination(self): +        (previous, current, next, previous_url, next_url) = self.get_pages('/') -class LinksSerializer(serializers.Serializer): -    next = pagination.NextPageField(source='*') -    prev = pagination.PreviousPageField(source='*') +        assert previous is None +        assert current == [1, 1, 1, 1, 1] +        assert next == [1, 2, 3, 4, 4] +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) -class CustomPaginationSerializer(pagination.BasePaginationSerializer): -    links = LinksSerializer(source='*')  # Takes the page object as the source -    total_results = serializers.ReadOnlyField(source='paginator.count') +        assert previous == [1, 1, 1, 1, 1] +        assert current == [1, 2, 3, 4, 4] +        assert next == [4, 4, 5, 6, 7] -    results_field = 'objects' +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) +        assert previous == [1, 2, 3, 4, 4] +        assert current == [4, 4, 5, 6, 7] +        assert next == [7, 7, 7, 7, 7] -class CustomFooSerializer(serializers.Serializer): -    foo = serializers.CharField() +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) +        assert previous == [4, 4, 4, 5, 6]  # Paging artifact +        assert current == [7, 7, 7, 7, 7] +        assert next == [7, 7, 7, 8, 9] -class CustomFooPaginationSerializer(pagination.PaginationSerializer): -    class Meta: -        object_serializer_class = CustomFooSerializer +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) +        assert previous == [7, 7, 7, 7, 7] +        assert current == [7, 7, 7, 8, 9] +        assert next == [9, 9, 9, 9, 9] -class TestCustomPaginationSerializer(TestCase): -    def setUp(self): -        objects = ['john', 'paul', 'george', 'ringo'] -        paginator = Paginator(objects, 2) -        self.page = paginator.page(1) +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) -    def test_custom_pagination_serializer(self): -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=self.page, -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': 'http://testserver/foobar?page=2', -                'prev': None -            }, -            'total_results': 4, -            'objects': ['john', 'paul'] -        } -        self.assertEqual(serializer.data, expected) +        assert previous == [7, 7, 7, 8, 9] +        assert current == [9, 9, 9, 9, 9] +        assert next is None -    def test_custom_pagination_serializer_with_custom_object_serializer(self): -        objects = [ -            {'foo': 'bar'}, -            {'foo': 'spam'} -        ] -        paginator = Paginator(objects, 1) -        page = paginator.page(1) -        serializer = CustomFooPaginationSerializer(page) -        serializer.data +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) +        assert previous == [7, 7, 7, 7, 7] +        assert current == [7, 7, 7, 8, 9] +        assert next == [9, 9, 9, 9, 9] -class NonIntegerPage(object): +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -    def __init__(self, paginator, object_list, prev_token, token, next_token): -        self.paginator = paginator -        self.object_list = object_list -        self.prev_token = prev_token -        self.token = token -        self.next_token = next_token +        assert previous == [4, 4, 5, 6, 7] +        assert current == [7, 7, 7, 7, 7] +        assert next == [8, 9, 9, 9, 9]  # Paging artifact -    def has_next(self): -        return not not self.next_token +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -    def next_page_number(self): -        return self.next_token +        assert previous == [1, 2, 3, 4, 4] +        assert current == [4, 4, 5, 6, 7] +        assert next == [7, 7, 7, 7, 7] -    def has_previous(self): -        return not not self.prev_token +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -    def previous_page_number(self): -        return self.prev_token +        assert previous == [1, 1, 1, 1, 1] +        assert current == [1, 2, 3, 4, 4] +        assert next == [4, 4, 5, 6, 7] +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -class NonIntegerPaginator(object): +        assert previous is None +        assert current == [1, 1, 1, 1, 1] +        assert next == [1, 2, 3, 4, 4] -    def __init__(self, object_list, per_page): -        self.object_list = object_list -        self.per_page = per_page +        assert isinstance(self.pagination.to_html(), type('')) -    def count(self): -        # pretend like we don't know how many pages we have -        return None -    def page(self, token=None): -        if token: -            try: -                first = self.object_list.index(token) -            except ValueError: -                first = 0 -        else: -            first = 0 -        n = len(self.object_list) -        last = min(first + self.per_page, n) -        prev_token = self.object_list[last - (2 * self.per_page)] if first else None -        next_token = self.object_list[last] if last < n else None -        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): -    def test_custom_pagination_serializer(self): -        objects = ['john', 'paul', 'george', 'ringo'] -        paginator = NonIntegerPaginator(objects, 2) - -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=paginator.page(), -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), -                'prev': None -            }, -            'total_results': None, -            'objects': objects[:2] -        } -        self.assertEqual(serializer.data, expected) +def test_get_displayed_page_numbers(): +    """ +    Test our contextual page display function. -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=paginator.page('george'), -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': None, -                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), -            }, -            'total_results': None, -            'objects': objects[2:] -        } -        self.assertEqual(serializer.data, expected) +    This determines which pages to display in a pagination control, +    given the current page and the last page. +    """ +    displayed_page_numbers = pagination._get_displayed_page_numbers + +    # At five pages or less, all pages are displayed, always. +    assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5] + +    # Between six and either pages we may have a single page break. +    assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6] +    assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6] +    assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6] +    assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6] +    assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6] +    assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6] + +    assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7] +    assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7] +    assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7] +    assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7] +    assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7] +    assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7] +    assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7] + +    assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8] +    assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8] +    assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8] +    assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8] +    assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8] +    assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8] +    assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8] +    assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8] + +    # At nine or more pages we may have two page breaks, one on each side. +    assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9] +    assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9] +    assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9] +    assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9] +    assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9] +    assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9] +    assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] +    assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] +    assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 1d2054ac..8816065a 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -4,13 +4,9 @@ from __future__ import unicode_literals  from django import forms  from django.core.files.uploadhandler import MemoryFileUploadHandler  from django.test import TestCase -from django.utils import unittest  from django.utils.six.moves import StringIO -from rest_framework.compat import etree  from rest_framework.exceptions import ParseError  from rest_framework.parsers import FormParser, FileUploadParser -from rest_framework.parsers import XMLParser -import datetime  class Form(forms.Form): @@ -32,62 +28,6 @@ class TestFormParser(TestCase):          self.assertEqual(Form(data).is_valid(), True) -class TestXMLParser(TestCase): -    def setUp(self): -        self._input = StringIO( -            '<?xml version="1.0" encoding="utf-8"?>' -            '<root>' -            '<field_a>121.0</field_a>' -            '<field_b>dasd</field_b>' -            '<field_c></field_c>' -            '<field_d>2011-12-25 12:45:00</field_d>' -            '</root>' -        ) -        self._data = { -            'field_a': 121, -            'field_b': 'dasd', -            'field_c': None, -            'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00) -        } -        self._complex_data_input = StringIO( -            '<?xml version="1.0" encoding="utf-8"?>' -            '<root>' -            '<creation_date>2011-12-25 12:45:00</creation_date>' -            '<sub_data_list>' -            '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>' -            '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>' -            '</sub_data_list>' -            '<name>name</name>' -            '</root>' -        ) -        self._complex_data = { -            "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), -            "name": "name", -            "sub_data_list": [ -                { -                    "sub_id": 1, -                    "sub_name": "first" -                }, -                { -                    "sub_id": 2, -                    "sub_name": "second" -                } -            ] -        } - -    @unittest.skipUnless(etree, 'defusedxml not installed') -    def test_parse(self): -        parser = XMLParser() -        data = parser.parse(self._input) -        self.assertEqual(data, self._data) - -    @unittest.skipUnless(etree, 'defusedxml not installed') -    def test_complex_data_parse(self): -        parser = XMLParser() -        data = parser.parse(self._complex_data_input) -        self.assertEqual(data, self._complex_data) - -  class TestFileUploadParser(TestCase):      def setUp(self):          class MockRequest(object): diff --git a/tests/test_relations.py b/tests/test_relations.py index d478d855..fbe176e2 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -35,7 +35,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):          with pytest.raises(serializers.ValidationError) as excinfo:              self.field.to_internal_value(4)          msg = excinfo.value.detail[0] -        assert msg == "Invalid pk '4' - object does not exist." +        assert msg == 'Invalid pk "4" - object does not exist.'      def test_pk_related_lookup_invalid_type(self):          with pytest.raises(serializers.ValidationError) as excinfo: diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index 2230c275..33b09713 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -1,5 +1,5 @@  from __future__ import unicode_literals -from django.conf.urls import patterns, url +from django.conf.urls import url  from django.test import TestCase  from rest_framework import serializers  from rest_framework.test import APIRequestFactory @@ -16,8 +16,7 @@ def dummy_view(request, pk):      pass -urlpatterns = patterns( -    '', +urlpatterns = [      url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),      url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),      url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), @@ -26,7 +25,7 @@ urlpatterns = patterns(      url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),      url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),      url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), -) +]  # ManyToMany diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 4f41144e..60a08225 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -1,26 +1,19 @@  # -*- coding: utf-8 -*-  from __future__ import unicode_literals - -from decimal import Decimal  from django.conf.urls import patterns, url, include  from django.core.cache import cache  from django.db import models  from django.test import TestCase -from django.utils import six, unittest -from django.utils.six import BytesIO -from django.utils.six.moves import StringIO +from django.utils import six  from django.utils.translation import ugettext_lazy as _  from rest_framework import status, permissions -from rest_framework.compat import yaml, etree +from rest_framework.compat import OrderedDict  from rest_framework.response import Response  from rest_framework.views import APIView -from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ -    XMLRenderer, JSONPRenderer, BrowsableAPIRenderer -from rest_framework.parsers import YAMLParser, XMLParser +from rest_framework.renderers import BaseRenderer, JSONRenderer, BrowsableAPIRenderer  from rest_framework.settings import api_settings  from rest_framework.test import APIRequestFactory  from collections import MutableMapping -import datetime  import json  import re @@ -112,8 +105,6 @@ urlpatterns = patterns(      url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),      url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),      url(r'^cache$', MockGETView.as_view()), -    url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), -    url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),      url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),      url(r'^html$', HTMLView.as_view()),      url(r'^html1$', HTMLView1.as_view()), @@ -413,207 +404,6 @@ class AsciiJSONRendererTests(TestCase):          self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8')) -class JSONPRendererTests(TestCase): -    """ -    Tests specific to the JSONP Renderer -    """ - -    urls = 'tests.test_renderers' - -    def test_without_callback_with_json_renderer(self): -        """ -        Test JSONP rendering with View JSON Renderer. -        """ -        resp = self.client.get( -            '/jsonp/jsonrenderer', -            HTTP_ACCEPT='application/javascript' -        ) -        self.assertEqual(resp.status_code, status.HTTP_200_OK) -        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') -        self.assertEqual( -            resp.content, -            ('callback(%s);' % _flat_repr).encode('ascii') -        ) - -    def test_without_callback_without_json_renderer(self): -        """ -        Test JSONP rendering without View JSON Renderer. -        """ -        resp = self.client.get( -            '/jsonp/nojsonrenderer', -            HTTP_ACCEPT='application/javascript' -        ) -        self.assertEqual(resp.status_code, status.HTTP_200_OK) -        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') -        self.assertEqual( -            resp.content, -            ('callback(%s);' % _flat_repr).encode('ascii') -        ) - -    def test_with_callback(self): -        """ -        Test JSONP rendering with callback function name. -        """ -        callback_func = 'myjsonpcallback' -        resp = self.client.get( -            '/jsonp/nojsonrenderer?callback=' + callback_func, -            HTTP_ACCEPT='application/javascript' -        ) -        self.assertEqual(resp.status_code, status.HTTP_200_OK) -        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') -        self.assertEqual( -            resp.content, -            ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii') -        ) - - -if yaml: -    _yaml_repr = 'foo: [bar, baz]\n' - -    class YAMLRendererTests(TestCase): -        """ -        Tests specific to the YAML Renderer -        """ - -        def test_render(self): -            """ -            Test basic YAML rendering. -            """ -            obj = {'foo': ['bar', 'baz']} -            renderer = YAMLRenderer() -            content = renderer.render(obj, 'application/yaml') -            self.assertEqual(content.decode('utf-8'), _yaml_repr) - -        def test_render_and_parse(self): -            """ -            Test rendering and then parsing returns the original object. -            IE obj -> render -> parse -> obj. -            """ -            obj = {'foo': ['bar', 'baz']} - -            renderer = YAMLRenderer() -            parser = YAMLParser() - -            content = renderer.render(obj, 'application/yaml') -            data = parser.parse(BytesIO(content)) -            self.assertEqual(obj, data) - -        def test_render_decimal(self): -            """ -            Test YAML decimal rendering. -            """ -            renderer = YAMLRenderer() -            content = renderer.render({'field': Decimal('111.2')}, 'application/yaml') -            self.assertYAMLContains(content.decode('utf-8'), "field: '111.2'") - -        def assertYAMLContains(self, content, string): -            self.assertTrue(string in content, '%r not in %r' % (string, content)) - -        def test_proper_encoding(self): -            obj = {'countries': ['United Kingdom', 'France', 'España']} -            renderer = YAMLRenderer() -            content = renderer.render(obj, 'application/yaml') -            self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8')) - - -class XMLRendererTestCase(TestCase): -    """ -    Tests specific to the XML Renderer -    """ - -    _complex_data = { -        "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), -        "name": "name", -        "sub_data_list": [ -            { -                "sub_id": 1, -                "sub_name": "first" -            }, -            { -                "sub_id": 2, -                "sub_name": "second" -            } -        ] -    } - -    def test_render_string(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': 'astring'}, 'application/xml') -        self.assertXMLContains(content, '<field>astring</field>') - -    def test_render_integer(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': 111}, 'application/xml') -        self.assertXMLContains(content, '<field>111</field>') - -    def test_render_datetime(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({ -            'field': datetime.datetime(2011, 12, 25, 12, 45, 00) -        }, 'application/xml') -        self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>') - -    def test_render_float(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': 123.4}, 'application/xml') -        self.assertXMLContains(content, '<field>123.4</field>') - -    def test_render_decimal(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': Decimal('111.2')}, 'application/xml') -        self.assertXMLContains(content, '<field>111.2</field>') - -    def test_render_none(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render({'field': None}, 'application/xml') -        self.assertXMLContains(content, '<field></field>') - -    def test_render_complex_data(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = renderer.render(self._complex_data, 'application/xml') -        self.assertXMLContains(content, '<sub_name>first</sub_name>') -        self.assertXMLContains(content, '<sub_name>second</sub_name>') - -    @unittest.skipUnless(etree, 'defusedxml not installed') -    def test_render_and_parse_complex_data(self): -        """ -        Test XML rendering. -        """ -        renderer = XMLRenderer() -        content = StringIO(renderer.render(self._complex_data, 'application/xml')) - -        parser = XMLParser() -        complex_data_out = parser.parse(content) -        error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) -        self.assertEqual(self._complex_data, complex_data_out, error_msg) - -    def assertXMLContains(self, xml, string): -        self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) -        self.assertTrue(xml.endswith('</root>')) -        self.assertTrue(string in xml, '%r not in %r' % (string, xml)) - -  # Tests for caching issue, #346  class CacheRenderTest(TestCase):      """ @@ -643,3 +433,25 @@ class CacheRenderTest(TestCase):          assert isinstance(cached_response, Response)          assert cached_response.content == response.content          assert cached_response.status_code == response.status_code + + +class TestJSONIndentationStyles: +    def test_indented(self): +        renderer = JSONRenderer() +        data = OrderedDict([('a', 1), ('b', 2)]) +        assert renderer.render(data) == b'{"a":1,"b":2}' + +    def test_compact(self): +        renderer = JSONRenderer() +        data = OrderedDict([('a', 1), ('b', 2)]) +        context = {'indent': 4} +        assert ( +            renderer.render(data, renderer_context=context) == +            b'{\n    "a": 1,\n    "b": 2\n}' +        ) + +    def test_long_form(self): +        renderer = JSONRenderer() +        renderer.compact = False +        data = OrderedDict([('a', 1), ('b', 2)]) +        assert renderer.render(data) == b'{"a": 1, "b": 2}' diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py index fb881a75..bc955b2e 100644 --- a/tests/test_serializer_bulk_update.py +++ b/tests/test_serializer_bulk_update.py @@ -101,7 +101,7 @@ class BulkCreateSerializerTests(TestCase):          serializer = self.BookSerializer(data=data, many=True)          self.assertEqual(serializer.is_valid(), False) -        expected_errors = {'non_field_errors': ['Expected a list of items but got type `int`.']} +        expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}          self.assertEqual(serializer.errors, expected_errors) @@ -118,6 +118,6 @@ class BulkCreateSerializerTests(TestCase):          serializer = self.BookSerializer(data=data, many=True)          self.assertEqual(serializer.is_valid(), False) -        expected_errors = {'non_field_errors': ['Expected a list of items but got type `dict`.']} +        expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}          self.assertEqual(serializer.errors, expected_errors) diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py index b04a937e..0cee91f1 100644 --- a/tests/test_templatetags.py +++ b/tests/test_templatetags.py @@ -54,7 +54,7 @@ class Issue1386Tests(TestCase):  class URLizerTests(TestCase):      """ -    Test if both JSON and YAML URLs are transformed into links well +    Test if JSON URLs are transformed into links well      """      def _urlize_dict_check(self, data):          """ @@ -73,14 +73,3 @@ class URLizerTests(TestCase):          data['"foo_set": [\n    "http://api/foos/1/"\n], '] = \              '"foo_set": [\n    "<a href="http://api/foos/1/">http://api/foos/1/</a>"\n], '          self._urlize_dict_check(data) - -    def test_yaml_with_url(self): -        """ -        Test if YAML URLs are transformed into links well -        """ -        data = {} -        data['''{users: 'http://api/users/'}'''] = \ -            '''{users: '<a href="http://api/users/">http://api/users/</a>'}''' -        data['''foo_set: ['http://api/foos/1/']'''] = \ -            '''foo_set: ['<a href="http://api/foos/1/">http://api/foos/1/</a>']''' -        self._urlize_dict_check(data) diff --git a/tests/test_versioning.py b/tests/test_versioning.py new file mode 100644 index 00000000..90ad8afd --- /dev/null +++ b/tests/test_versioning.py @@ -0,0 +1,264 @@ +from .utils import UsingURLPatterns +from django.conf.urls import include, url +from rest_framework import serializers +from rest_framework import status, versioning +from rest_framework.decorators import APIView +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory, APITestCase +from rest_framework.versioning import NamespaceVersioning +import pytest + + +class RequestVersionView(APIView): +    def get(self, request, *args, **kwargs): +        return Response({'version': request.version}) + + +class ReverseView(APIView): +    def get(self, request, *args, **kwargs): +        return Response({'url': reverse('another', request=request)}) + + +class RequestInvalidVersionView(APIView): +    def determine_version(self, request, *args, **kwargs): +        scheme = self.versioning_class() +        scheme.allowed_versions = ('v1', 'v2') +        return (scheme.determine_version(request, *args, **kwargs), scheme) + +    def get(self, request, *args, **kwargs): +        return Response({'version': request.version}) + + +factory = APIRequestFactory() + + +def dummy_view(request): +    pass + + +def dummy_pk_view(request, pk): +    pass + + +class TestRequestVersion: +    def test_unversioned(self): +        view = RequestVersionView.as_view() + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_query_param_versioning(self): +        scheme = versioning.QueryParameterVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/?version=1.2.3') +        response = view(request) +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_host_name_versioning(self): +        scheme = versioning.HostNameVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_HOST='v1.example.org') +        response = view(request) +        assert response.data == {'version': 'v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_accept_header_versioning(self): +        scheme = versioning.AcceptHeaderVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=1.2.3') +        response = view(request) +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/', HTTP_ACCEPT='application/json') +        response = view(request) +        assert response.data == {'version': None} + +    def test_url_path_versioning(self): +        scheme = versioning.URLPathVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/1.2.3/endpoint/') +        response = view(request, version='1.2.3') +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_namespace_versioning(self): +        class FakeResolverMatch: +            namespace = 'v1' + +        scheme = versioning.NamespaceVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        request.resolver_match = FakeResolverMatch +        response = view(request, version='v1') +        assert response.data == {'version': 'v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + + +class TestURLReversing(UsingURLPatterns, APITestCase): +    included = [ +        url(r'^namespaced/$', dummy_view, name='another'), +        url(r'^example/(?P<pk>\d+)/$', dummy_pk_view, name='example-detail') +    ] + +    urlpatterns = [ +        url(r'^v1/', include(included, namespace='v1')), +        url(r'^another/$', dummy_view, name='another'), +        url(r'^(?P<version>[^/]+)/another/$', dummy_view, name='another'), +    ] + +    def test_reverse_unversioned(self): +        view = ReverseView.as_view() + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_query_param_versioning(self): +        scheme = versioning.QueryParameterVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/?version=v1') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/?version=v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_host_name_versioning(self): +        scheme = versioning.HostNameVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_HOST='v1.example.org') +        response = view(request) +        assert response.data == {'url': 'http://v1.example.org/another/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_url_path_versioning(self): +        scheme = versioning.URLPathVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        response = view(request, version='v1') +        assert response.data == {'url': 'http://testserver/v1/another/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_namespace_versioning(self): +        class FakeResolverMatch: +            namespace = 'v1' + +        scheme = versioning.NamespaceVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        request.resolver_match = FakeResolverMatch +        response = view(request, version='v1') +        assert response.data == {'url': 'http://testserver/v1/namespaced/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + + +class TestInvalidVersion: +    def test_invalid_query_param_versioning(self): +        scheme = versioning.QueryParameterVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/?version=v3') +        response = view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND + +    def test_invalid_host_name_versioning(self): +        scheme = versioning.HostNameVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_HOST='v3.example.org') +        response = view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND + +    def test_invalid_accept_header_versioning(self): +        scheme = versioning.AcceptHeaderVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3') +        response = view(request) +        assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE + +    def test_invalid_url_path_versioning(self): +        scheme = versioning.URLPathVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/v3/endpoint/') +        response = view(request, version='v3') +        assert response.status_code == status.HTTP_404_NOT_FOUND + +    def test_invalid_namespace_versioning(self): +        class FakeResolverMatch: +            namespace = 'v3' + +        scheme = versioning.NamespaceVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/v3/endpoint/') +        request.resolver_match = FakeResolverMatch +        response = view(request, version='v3') +        assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase): +    included = [ +        url(r'^namespaced/(?P<pk>\d+)/$', dummy_view, name='namespaced'), +    ] + +    urlpatterns = [ +        url(r'^v1/', include(included, namespace='v1')), +        url(r'^v2/', include(included, namespace='v2')) +    ] + +    def setUp(self): +        super(TestHyperlinkedRelatedField, self).setUp() + +        class MockQueryset(object): +            def get(self, pk): +                return 'object %s' % pk + +        self.field = serializers.HyperlinkedRelatedField( +            view_name='namespaced', +            queryset=MockQueryset() +        ) +        request = factory.get('/') +        request.versioning_scheme = NamespaceVersioning() +        request.version = 'v1' +        self.field._context = {'request': request} + +    def test_bug_2489(self): +        assert self.field.to_internal_value('/v1/namespaced/3/') == 'object 3' +        with pytest.raises(serializers.ValidationError): +            self.field.to_internal_value('/v2/namespaced/3/') diff --git a/tests/utils.py b/tests/utils.py index 5b2d7586..b9034996 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,30 @@ from django.core.exceptions import ObjectDoesNotExist  from django.core.urlresolvers import NoReverseMatch +class UsingURLPatterns(object): +    """ +    Isolates URL patterns used during testing on the test class itself. +    For example: + +    class MyTestCase(UsingURLPatterns, TestCase): +        urlpatterns = [ +            ... +        ] + +        def test_something(self): +            ... +    """ +    urls = __name__ + +    def setUp(self): +        global urlpatterns +        urlpatterns = self.urlpatterns + +    def tearDown(self): +        global urlpatterns +        urlpatterns = [] + +  class MockObject(object):      def __init__(self, **kwargs):          self._kwargs = kwargs | 
