diff options
Diffstat (limited to 'tests')
54 files changed, 11322 insertions, 0 deletions
| diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/tests/__init__.py diff --git a/tests/browsable_api/__init__.py b/tests/browsable_api/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/tests/browsable_api/__init__.py diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py new file mode 100644 index 00000000..97bc1036 --- /dev/null +++ b/tests/browsable_api/auth_urls.py @@ -0,0 +1,11 @@ +from __future__ import unicode_literals +from django.conf.urls import patterns, url, include + +from .views import MockView + + +urlpatterns = patterns( +    '', +    (r'^$', MockView.as_view()), +    url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), +) diff --git a/tests/browsable_api/no_auth_urls.py b/tests/browsable_api/no_auth_urls.py new file mode 100644 index 00000000..5e3604a6 --- /dev/null +++ b/tests/browsable_api/no_auth_urls.py @@ -0,0 +1,9 @@ +from __future__ import unicode_literals +from django.conf.urls import patterns + +from .views import MockView + +urlpatterns = patterns( +    '', +    (r'^$', MockView.as_view()), +) diff --git a/tests/browsable_api/test_browsable_api.py b/tests/browsable_api/test_browsable_api.py new file mode 100644 index 00000000..5f264783 --- /dev/null +++ b/tests/browsable_api/test_browsable_api.py @@ -0,0 +1,65 @@ +from __future__ import unicode_literals +from django.contrib.auth.models import User +from django.test import TestCase + +from rest_framework.test import APIClient + + +class DropdownWithAuthTests(TestCase): +    """Tests correct dropdown behaviour with Auth views enabled.""" + +    urls = 'tests.browsable_api.auth_urls' + +    def setUp(self): +        self.client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +    def tearDown(self): +        self.client.logout() + +    def test_name_shown_when_logged_in(self): +        self.client.login(username=self.username, password=self.password) +        response = self.client.get('/') +        self.assertContains(response, 'john') + +    def test_logout_shown_when_logged_in(self): +        self.client.login(username=self.username, password=self.password) +        response = self.client.get('/') +        self.assertContains(response, '>Log out<') + +    def test_login_shown_when_logged_out(self): +        response = self.client.get('/') +        self.assertContains(response, '>Log in<') + + +class NoDropdownWithoutAuthTests(TestCase): +    """Tests correct dropdown behaviour with Auth views NOT enabled.""" + +    urls = 'tests.browsable_api.no_auth_urls' + +    def setUp(self): +        self.client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +    def tearDown(self): +        self.client.logout() + +    def test_name_shown_when_logged_in(self): +        self.client.login(username=self.username, password=self.password) +        response = self.client.get('/') +        self.assertContains(response, 'john') + +    def test_dropdown_not_shown_when_logged_in(self): +        self.client.login(username=self.username, password=self.password) +        response = self.client.get('/') +        self.assertNotContains(response, '<li class="dropdown">') + +    def test_dropdown_not_shown_when_logged_out(self): +        response = self.client.get('/') +        self.assertNotContains(response, '<li class="dropdown">') diff --git a/tests/browsable_api/views.py b/tests/browsable_api/views.py new file mode 100644 index 00000000..000f4e80 --- /dev/null +++ b/tests/browsable_api/views.py @@ -0,0 +1,15 @@ +from __future__ import unicode_literals + +from rest_framework.views import APIView +from rest_framework import authentication +from rest_framework import renderers +from rest_framework.response import Response + + +class MockView(APIView): + +    authentication_classes = (authentication.SessionAuthentication,) +    renderer_classes = (renderers.BrowsableAPIRenderer,) + +    def get(self, request): +        return Response({'a': 1, 'b': 2, 'c': 3}) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..44ed070b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,66 @@ +def pytest_configure(): +    from django.conf import settings + +    settings.configure( +        DEBUG_PROPAGATE_EXCEPTIONS=True, +        DATABASES={'default': {'ENGINE': 'django.db.backends.sqlite3', +                               'NAME': ':memory:'}}, +        SITE_ID=1, +        SECRET_KEY='not very secret in tests', +        USE_I18N=True, +        USE_L10N=True, +        STATIC_URL='/static/', +        ROOT_URLCONF='tests.urls', +        TEMPLATE_LOADERS=( +            'django.template.loaders.filesystem.Loader', +            'django.template.loaders.app_directories.Loader', +        ), +        MIDDLEWARE_CLASSES=( +            'django.middleware.common.CommonMiddleware', +            'django.contrib.sessions.middleware.SessionMiddleware', +            'django.middleware.csrf.CsrfViewMiddleware', +            'django.contrib.auth.middleware.AuthenticationMiddleware', +            'django.contrib.messages.middleware.MessageMiddleware', +        ), +        INSTALLED_APPS=( +            'django.contrib.auth', +            'django.contrib.contenttypes', +            'django.contrib.sessions', +            'django.contrib.sites', +            'django.contrib.messages', +            'django.contrib.staticfiles', + +            'rest_framework', +            'rest_framework.authtoken', +            'tests', +        ), +        PASSWORD_HASHERS=( +            'django.contrib.auth.hashers.SHA1PasswordHasher', +            'django.contrib.auth.hashers.PBKDF2PasswordHasher', +            'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', +            'django.contrib.auth.hashers.BCryptPasswordHasher', +            'django.contrib.auth.hashers.MD5PasswordHasher', +            'django.contrib.auth.hashers.CryptPasswordHasher', +        ), +    ) + +    # guardian is optional +    try: +        import guardian  # NOQA +    except ImportError: +        pass +    else: +        settings.ANONYMOUS_USER_ID = -1 +        settings.AUTHENTICATION_BACKENDS = ( +            'django.contrib.auth.backends.ModelBackend', +            'guardian.backends.ObjectPermissionBackend', +        ) +        settings.INSTALLED_APPS += ( +            'guardian', +        ) + +    try: +        import django +        django.setup() +    except AttributeError: +        pass diff --git a/tests/description.py b/tests/description.py new file mode 100644 index 00000000..b46d7f54 --- /dev/null +++ b/tests/description.py @@ -0,0 +1,26 @@ +# -- coding: utf-8 -- + +# Apparently there is a python 2.6 issue where docstrings of imported view classes +# do not retain their encoding information even if a module has a proper +# encoding declaration at the top of its source file. Therefore for tests +# to catch unicode related errors, a mock view has to be declared in a separate +# module. + +from rest_framework.views import APIView + + +# test strings snatched from http://www.columbia.edu/~fdc/utf8/, +# http://winrus.com/utf8-jap.htm and memory +UTF8_TEST_DOCSTRING = ( +    'zażółć gęślą jaźń' +    'Sîne klâwen durh die wolken sint geslagen' +    'Τη γλώσσα μου έδωσαν ελληνική' +    'யாமறிந்த மொழிகளிலே தமிழ்மொழி' +    'На берегу пустынных волн' +    'てすと' +    'アイウエオカキクケコサシスセソタチツテ' +) + + +class ViewWithNonASCIICharactersInDocstring(APIView): +    __doc__ = UTF8_TEST_DOCSTRING diff --git a/tests/models.py b/tests/models.py new file mode 100644 index 00000000..456b0a0b --- /dev/null +++ b/tests/models.py @@ -0,0 +1,70 @@ +from __future__ import unicode_literals +from django.db import models +from django.utils.translation import ugettext_lazy as _ + + +class RESTFrameworkModel(models.Model): +    """ +    Base for test models that sets app_label, so they play nicely. +    """ + +    class Meta: +        app_label = 'tests' +        abstract = True + + +class BasicModel(RESTFrameworkModel): +    text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description.")) + + +class BaseFilterableItem(RESTFrameworkModel): +    text = models.CharField(max_length=100) + +    class Meta: +        abstract = True + + +class FilterableItem(BaseFilterableItem): +    decimal = models.DecimalField(max_digits=4, decimal_places=2) +    date = models.DateField() + + +# Models for relations tests +# ManyToMany +class ManyToManyTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class ManyToManySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') + + +# ForeignKey +class ForeignKeyTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class ForeignKeySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, related_name='sources', +                               help_text='Target', verbose_name='Target') + + +# Nullable ForeignKey +class NullableForeignKeySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, +                               related_name='nullable_sources', +                               verbose_name='Optional target object') + + +# OneToOne +class OneToOneTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class NullableOneToOneSource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.OneToOneField(OneToOneTarget, null=True, blank=True, +                                  related_name='nullable_source') diff --git a/tests/test_authentication.py b/tests/test_authentication.py new file mode 100644 index 00000000..91e49f9d --- /dev/null +++ b/tests/test_authentication.py @@ -0,0 +1,288 @@ +from __future__ import unicode_literals +from django.conf.urls import patterns, url, include +from django.contrib.auth.models import User +from django.http import HttpResponse +from django.test import TestCase +from django.utils import six +from rest_framework import HTTP_HEADER_ENCODING +from rest_framework import exceptions +from rest_framework import permissions +from rest_framework import renderers +from rest_framework.response import Response +from rest_framework import status +from rest_framework.authentication import ( +    BaseAuthentication, +    TokenAuthentication, +    BasicAuthentication, +    SessionAuthentication, +) +from rest_framework.authtoken.models import Token +from rest_framework.test import APIRequestFactory, APIClient +from rest_framework.views import APIView +import base64 + +factory = APIRequestFactory() + + +class MockView(APIView): +    permission_classes = (permissions.IsAuthenticated,) + +    def get(self, request): +        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + +    def post(self, request): +        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + +    def put(self, request): +        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + + +urlpatterns = patterns( +    '', +    (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), +    (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), +    (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), +    (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), +    url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')) +) + + +class BasicAuthTests(TestCase): +    """Basic authentication""" +    urls = 'tests.test_authentication' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +    def test_post_form_passing_basic_auth(self): +        """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" +        credentials = ('%s:%s' % (self.username, self.password)) +        base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) +        auth = 'Basic %s' % base64_credentials +        response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_json_passing_basic_auth(self): +        """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" +        credentials = ('%s:%s' % (self.username, self.password)) +        base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) +        auth = 'Basic %s' % base64_credentials +        response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_form_failing_basic_auth(self): +        """Ensure POSTing form over basic auth without correct credentials fails""" +        response = self.csrf_client.post('/basic/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + +    def test_post_json_failing_basic_auth(self): +        """Ensure POSTing json over basic auth without correct credentials fails""" +        response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json') +        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) +        self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') + + +class SessionAuthTests(TestCase): +    """User session authentication""" +    urls = 'tests.test_authentication' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.non_csrf_client = APIClient(enforce_csrf_checks=False) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +    def tearDown(self): +        self.csrf_client.logout() + +    def test_login_view_renders_on_get(self): +        """ +        Ensure the login template renders for a basic GET. + +        cf. [#1810](https://github.com/tomchristie/django-rest-framework/pull/1810) +        """ +        response = self.csrf_client.get('/auth/login/') +        self.assertContains(response, '<label for="id_username">Username:</label>') + +    def test_post_form_session_auth_failing_csrf(self): +        """ +        Ensure POSTing form over session authentication without CSRF token fails. +        """ +        self.csrf_client.login(username=self.username, password=self.password) +        response = self.csrf_client.post('/session/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_post_form_session_auth_passing(self): +        """ +        Ensure POSTing form over session authentication with logged in user and CSRF token passes. +        """ +        self.non_csrf_client.login(username=self.username, password=self.password) +        response = self.non_csrf_client.post('/session/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_put_form_session_auth_passing(self): +        """ +        Ensure PUTting form over session authentication with logged in user and CSRF token passes. +        """ +        self.non_csrf_client.login(username=self.username, password=self.password) +        response = self.non_csrf_client.put('/session/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_form_session_auth_failing(self): +        """ +        Ensure POSTing form over session authentication without logged in user fails. +        """ +        response = self.csrf_client.post('/session/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + +class TokenAuthTests(TestCase): +    """Token authentication""" +    urls = 'tests.test_authentication' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +        self.key = 'abcd1234' +        self.token = Token.objects.create(key=self.key, user=self.user) + +    def test_post_form_passing_token_auth(self): +        """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" +        auth = 'Token ' + self.key +        response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_json_passing_token_auth(self): +        """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" +        auth = "Token " + self.key +        response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_json_makes_one_db_query(self): +        """Ensure that authenticating a user using a token performs only one DB query""" +        auth = "Token " + self.key + +        def func_to_test(): +            return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) + +        self.assertNumQueries(1, func_to_test) + +    def test_post_form_failing_token_auth(self): +        """Ensure POSTing form over token auth without correct credentials fails""" +        response = self.csrf_client.post('/token/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + +    def test_post_json_failing_token_auth(self): +        """Ensure POSTing json over token auth without correct credentials fails""" +        response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') +        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + +    def test_token_has_auto_assigned_key_if_none_provided(self): +        """Ensure creating a token with no key will auto-assign a key""" +        self.token.delete() +        token = Token.objects.create(user=self.user) +        self.assertTrue(bool(token.key)) + +    def test_generate_key_returns_string(self): +        """Ensure generate_key returns a string""" +        token = Token() +        key = token.generate_key() +        self.assertTrue(isinstance(key, six.string_types)) + +    def test_token_login_json(self): +        """Ensure token login view using JSON POST works.""" +        client = APIClient(enforce_csrf_checks=True) +        response = client.post('/auth-token/', +                               {'username': self.username, 'password': self.password}, format='json') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['token'], self.key) + +    def test_token_login_json_bad_creds(self): +        """Ensure token login view using JSON POST fails if bad credentials are used.""" +        client = APIClient(enforce_csrf_checks=True) +        response = client.post('/auth-token/', +                               {'username': self.username, 'password': "badpass"}, format='json') +        self.assertEqual(response.status_code, 400) + +    def test_token_login_json_missing_fields(self): +        """Ensure token login view using JSON POST fails if missing fields.""" +        client = APIClient(enforce_csrf_checks=True) +        response = client.post('/auth-token/', +                               {'username': self.username}, format='json') +        self.assertEqual(response.status_code, 400) + +    def test_token_login_form(self): +        """Ensure token login view using form POST works.""" +        client = APIClient(enforce_csrf_checks=True) +        response = client.post('/auth-token/', +                               {'username': self.username, 'password': self.password}) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['token'], self.key) + + +class IncorrectCredentialsTests(TestCase): +    def test_incorrect_credentials(self): +        """ +        If a request contains bad authentication credentials, then +        authentication should run and error, even if no permissions +        are set on the view. +        """ +        class IncorrectCredentialsAuth(BaseAuthentication): +            def authenticate(self, request): +                raise exceptions.AuthenticationFailed('Bad credentials') + +        request = factory.get('/') +        view = MockView.as_view( +            authentication_classes=(IncorrectCredentialsAuth,), +            permission_classes=() +        ) +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) +        self.assertEqual(response.data, {'detail': 'Bad credentials'}) + + +class FailingAuthAccessedInRenderer(TestCase): +    def setUp(self): +        class AuthAccessingRenderer(renderers.BaseRenderer): +            media_type = 'text/plain' +            format = 'txt' + +            def render(self, data, media_type=None, renderer_context=None): +                request = renderer_context['request'] +                if request.user.is_authenticated(): +                    return b'authenticated' +                return b'not authenticated' + +        class FailingAuth(BaseAuthentication): +            def authenticate(self, request): +                raise exceptions.AuthenticationFailed('authentication failed') + +        class ExampleView(APIView): +            authentication_classes = (FailingAuth,) +            renderer_classes = (AuthAccessingRenderer,) + +            def get(self, request): +                return Response({'foo': 'bar'}) + +        self.view = ExampleView.as_view() + +    def test_failing_auth_accessed_in_renderer(self): +        """ +        When authentication fails the renderer should still be able to access +        `request.user` without raising an exception. Particularly relevant +        to HTML responses that might reasonably access `request.user`. +        """ +        request = factory.get('/') +        response = self.view(request) +        content = response.render().content +        self.assertEqual(content, b'not authenticated') diff --git a/tests/test_bound_fields.py b/tests/test_bound_fields.py new file mode 100644 index 00000000..bfc54b23 --- /dev/null +++ b/tests/test_bound_fields.py @@ -0,0 +1,69 @@ +from rest_framework import serializers + + +class TestSimpleBoundField: +    def test_empty_bound_field(self): +        class ExampleSerializer(serializers.Serializer): +            text = serializers.CharField(max_length=100) +            amount = serializers.IntegerField() + +        serializer = ExampleSerializer() + +        assert serializer['text'].value == '' +        assert serializer['text'].errors is None +        assert serializer['text'].name == 'text' +        assert serializer['amount'].value is None +        assert serializer['amount'].errors is None +        assert serializer['amount'].name == 'amount' + +    def test_populated_bound_field(self): +        class ExampleSerializer(serializers.Serializer): +            text = serializers.CharField(max_length=100) +            amount = serializers.IntegerField() + +        serializer = ExampleSerializer(data={'text': 'abc', 'amount': 123}) +        assert serializer.is_valid() +        assert serializer['text'].value == 'abc' +        assert serializer['text'].errors is None +        assert serializer['text'].name == 'text' +        assert serializer['amount'].value is 123 +        assert serializer['amount'].errors is None +        assert serializer['amount'].name == 'amount' + +    def test_error_bound_field(self): +        class ExampleSerializer(serializers.Serializer): +            text = serializers.CharField(max_length=100) +            amount = serializers.IntegerField() + +        serializer = ExampleSerializer(data={'text': 'x' * 1000, 'amount': 123}) +        serializer.is_valid() + +        assert serializer['text'].value == 'x' * 1000 +        assert serializer['text'].errors == ['Ensure this field has no more than 100 characters.'] +        assert serializer['text'].name == 'text' +        assert serializer['amount'].value is 123 +        assert serializer['amount'].errors is None +        assert serializer['amount'].name == 'amount' + + +class TestNestedBoundField: +    def test_nested_empty_bound_field(self): +        class Nested(serializers.Serializer): +            more_text = serializers.CharField(max_length=100) +            amount = serializers.IntegerField() + +        class ExampleSerializer(serializers.Serializer): +            text = serializers.CharField(max_length=100) +            nested = Nested() + +        serializer = ExampleSerializer() + +        assert serializer['text'].value == '' +        assert serializer['text'].errors is None +        assert serializer['text'].name == 'text' +        assert serializer['nested']['more_text'].value == '' +        assert serializer['nested']['more_text'].errors is None +        assert serializer['nested']['more_text'].name == 'nested.more_text' +        assert serializer['nested']['amount'].value is None +        assert serializer['nested']['amount'].errors is None +        assert serializer['nested']['amount'].name == 'nested.amount' diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 00000000..195f0ba3 --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,157 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import status +from rest_framework.authentication import BasicAuthentication +from rest_framework.parsers import JSONParser +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.renderers import JSONRenderer +from rest_framework.test import APIRequestFactory +from rest_framework.throttling import UserRateThrottle +from rest_framework.views import APIView +from rest_framework.decorators import ( +    api_view, +    renderer_classes, +    parser_classes, +    authentication_classes, +    throttle_classes, +    permission_classes, +) + + +class DecoratorTestCase(TestCase): + +    def setUp(self): +        self.factory = APIRequestFactory() + +    def _finalize_response(self, request, response, *args, **kwargs): +        response.request = request +        return APIView.finalize_response(self, request, response, *args, **kwargs) + +    def test_api_view_incorrect(self): +        """ +        If @api_view is not applied correct, we should raise an assertion. +        """ + +        @api_view +        def view(request): +            return Response() + +        request = self.factory.get('/') +        self.assertRaises(AssertionError, view, request) + +    def test_api_view_incorrect_arguments(self): +        """ +        If @api_view is missing arguments, we should raise an assertion. +        """ + +        with self.assertRaises(AssertionError): +            @api_view('GET') +            def view(request): +                return Response() + +    def test_calling_method(self): + +        @api_view(['GET']) +        def view(request): +            return Response({}) + +        request = self.factory.get('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        request = self.factory.post('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + +    def test_calling_put_method(self): + +        @api_view(['GET', 'PUT']) +        def view(request): +            return Response({}) + +        request = self.factory.put('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        request = self.factory.post('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + +    def test_calling_patch_method(self): + +        @api_view(['GET', 'PATCH']) +        def view(request): +            return Response({}) + +        request = self.factory.patch('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        request = self.factory.post('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + +    def test_renderer_classes(self): + +        @api_view(['GET']) +        @renderer_classes([JSONRenderer]) +        def view(request): +            return Response({}) + +        request = self.factory.get('/') +        response = view(request) +        self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer)) + +    def test_parser_classes(self): + +        @api_view(['GET']) +        @parser_classes([JSONParser]) +        def view(request): +            self.assertEqual(len(request.parsers), 1) +            self.assertTrue(isinstance(request.parsers[0], +                                       JSONParser)) +            return Response({}) + +        request = self.factory.get('/') +        view(request) + +    def test_authentication_classes(self): + +        @api_view(['GET']) +        @authentication_classes([BasicAuthentication]) +        def view(request): +            self.assertEqual(len(request.authenticators), 1) +            self.assertTrue(isinstance(request.authenticators[0], +                                       BasicAuthentication)) +            return Response({}) + +        request = self.factory.get('/') +        view(request) + +    def test_permission_classes(self): + +        @api_view(['GET']) +        @permission_classes([IsAuthenticated]) +        def view(request): +            return Response({}) + +        request = self.factory.get('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_throttle_classes(self): +        class OncePerDayUserThrottle(UserRateThrottle): +            rate = '1/day' + +        @api_view(['GET']) +        @throttle_classes([OncePerDayUserThrottle]) +        def view(request): +            return Response({}) + +        request = self.factory.get('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) diff --git a/tests/test_description.py b/tests/test_description.py new file mode 100644 index 00000000..78ce2350 --- /dev/null +++ b/tests/test_description.py @@ -0,0 +1,131 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from django.test import TestCase +from django.utils.encoding import python_2_unicode_compatible, smart_text +from rest_framework.compat import apply_markdown +from rest_framework.views import APIView +from .description import ViewWithNonASCIICharactersInDocstring +from .description import UTF8_TEST_DOCSTRING + +# We check that docstrings get nicely un-indented. +DESCRIPTION = """an example docstring +==================== + +* list +* list + +another header +-------------- + +    code block + +indented + +# hash style header #""" + +# If markdown is installed we also test it's working +# (and that our wrapped forces '=' to h2 and '-' to h3) + +# We support markdown < 2.1 and markdown >= 2.1 +MARKED_DOWN_lt_21 = """<h2>an example docstring</h2> +<ul> +<li>list</li> +<li>list</li> +</ul> +<h3>another header</h3> +<pre><code>code block +</code></pre> +<p>indented</p> +<h2 id="hash_style_header">hash style header</h2>""" + +MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2> +<ul> +<li>list</li> +<li>list</li> +</ul> +<h3 id="another-header">another header</h3> +<pre><code>code block +</code></pre> +<p>indented</p> +<h2 id="hash-style-header">hash style header</h2>""" + + +class TestViewNamesAndDescriptions(TestCase): +    def test_view_name_uses_class_name(self): +        """ +        Ensure view names are based on the class name. +        """ +        class MockView(APIView): +            pass +        self.assertEqual(MockView().get_view_name(), 'Mock') + +    def test_view_description_uses_docstring(self): +        """Ensure view descriptions are based on the docstring.""" +        class MockView(APIView): +            """an example docstring +            ==================== + +            * list +            * list + +            another header +            -------------- + +                code block + +            indented + +            # hash style header #""" + +        self.assertEqual(MockView().get_view_description(), DESCRIPTION) + +    def test_view_description_supports_unicode(self): +        """ +        Unicode in docstrings should be respected. +        """ + +        self.assertEqual( +            ViewWithNonASCIICharactersInDocstring().get_view_description(), +            smart_text(UTF8_TEST_DOCSTRING) +        ) + +    def test_view_description_can_be_empty(self): +        """ +        Ensure that if a view has no docstring, +        then it's description is the empty string. +        """ +        class MockView(APIView): +            pass +        self.assertEqual(MockView().get_view_description(), '') + +    def test_view_description_can_be_promise(self): +        """ +        Ensure a view may have a docstring that is actually a lazily evaluated +        class that can be converted to a string. + +        See: https://github.com/tomchristie/django-rest-framework/issues/1708 +        """ +        # use a mock object instead of gettext_lazy to ensure that we can't end +        # up with a test case string in our l10n catalog +        @python_2_unicode_compatible +        class MockLazyStr(object): +            def __init__(self, string): +                self.s = string + +            def __str__(self): +                return self.s + +        class MockView(APIView): +            __doc__ = MockLazyStr("a gettext string") + +        self.assertEqual(MockView().get_view_description(), 'a gettext string') + +    def test_markdown(self): +        """ +        Ensure markdown to HTML works as expected. +        """ +        if apply_markdown: +            gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21 +            lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 +            self.assertTrue(gte_21_match or lt_21_match) diff --git a/tests/test_fields.py b/tests/test_fields.py new file mode 100644 index 00000000..1aa528da --- /dev/null +++ b/tests/test_fields.py @@ -0,0 +1,1212 @@ +from decimal import Decimal +from django.utils import timezone +from rest_framework import serializers +import datetime +import django +import pytest +import uuid + + +# Tests for field keyword arguments and core functionality. +# --------------------------------------------------------- + +class TestEmpty: +    """ +    Tests for `required`, `allow_null`, `allow_blank`, `default`. +    """ +    def test_required(self): +        """ +        By default a field must be included in the input. +        """ +        field = serializers.IntegerField() +        with pytest.raises(serializers.ValidationError) as exc_info: +            field.run_validation() +        assert exc_info.value.detail == ['This field is required.'] + +    def test_not_required(self): +        """ +        If `required=False` then a field may be omitted from the input. +        """ +        field = serializers.IntegerField(required=False) +        with pytest.raises(serializers.SkipField): +            field.run_validation() + +    def test_disallow_null(self): +        """ +        By default `None` is not a valid input. +        """ +        field = serializers.IntegerField() +        with pytest.raises(serializers.ValidationError) as exc_info: +            field.run_validation(None) +        assert exc_info.value.detail == ['This field may not be null.'] + +    def test_allow_null(self): +        """ +        If `allow_null=True` then `None` is a valid input. +        """ +        field = serializers.IntegerField(allow_null=True) +        output = field.run_validation(None) +        assert output is None + +    def test_disallow_blank(self): +        """ +        By default '' is not a valid input. +        """ +        field = serializers.CharField() +        with pytest.raises(serializers.ValidationError) as exc_info: +            field.run_validation('') +        assert exc_info.value.detail == ['This field may not be blank.'] + +    def test_allow_blank(self): +        """ +        If `allow_blank=True` then '' is a valid input. +        """ +        field = serializers.CharField(allow_blank=True) +        output = field.run_validation('') +        assert output == '' + +    def test_default(self): +        """ +        If `default` is set, then omitted values get the default input. +        """ +        field = serializers.IntegerField(default=123) +        output = field.run_validation() +        assert output is 123 + + +class TestSource: +    def test_source(self): +        class ExampleSerializer(serializers.Serializer): +            example_field = serializers.CharField(source='other') +        serializer = ExampleSerializer(data={'example_field': 'abc'}) +        assert serializer.is_valid() +        assert serializer.validated_data == {'other': 'abc'} + +    def test_redundant_source(self): +        class ExampleSerializer(serializers.Serializer): +            example_field = serializers.CharField(source='example_field') +        with pytest.raises(AssertionError) as exc_info: +            ExampleSerializer().fields +        assert str(exc_info.value) == ( +            "It is redundant to specify `source='example_field'` on field " +            "'CharField' in serializer 'ExampleSerializer', because it is the " +            "same as the field name. Remove the `source` keyword argument." +        ) + +    def test_callable_source(self): +        class ExampleSerializer(serializers.Serializer): +            example_field = serializers.CharField(source='example_callable') + +        class ExampleInstance(object): +            def example_callable(self): +                return 'example callable value' + +        serializer = ExampleSerializer(ExampleInstance()) +        assert serializer.data['example_field'] == 'example callable value' + +    def test_callable_source_raises(self): +        class ExampleSerializer(serializers.Serializer): +            example_field = serializers.CharField(source='example_callable', read_only=True) + +        class ExampleInstance(object): +            def example_callable(self): +                raise AttributeError('method call failed') + +        with pytest.raises(ValueError) as exc_info: +            serializer = ExampleSerializer(ExampleInstance()) +            serializer.data.items() + +        assert 'method call failed' in str(exc_info.value) + + +class TestReadOnly: +    def setup(self): +        class TestSerializer(serializers.Serializer): +            read_only = serializers.ReadOnlyField() +            writable = serializers.IntegerField() +        self.Serializer = TestSerializer + +    def test_validate_read_only(self): +        """ +        Read-only serializers.should not be included in validation. +        """ +        data = {'read_only': 123, 'writable': 456} +        serializer = self.Serializer(data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == {'writable': 456} + +    def test_serialize_read_only(self): +        """ +        Read-only serializers.should be serialized. +        """ +        instance = {'read_only': 123, 'writable': 456} +        serializer = self.Serializer(instance) +        assert serializer.data == {'read_only': 123, 'writable': 456} + + +class TestWriteOnly: +    def setup(self): +        class TestSerializer(serializers.Serializer): +            write_only = serializers.IntegerField(write_only=True) +            readable = serializers.IntegerField() +        self.Serializer = TestSerializer + +    def test_validate_write_only(self): +        """ +        Write-only serializers.should be included in validation. +        """ +        data = {'write_only': 123, 'readable': 456} +        serializer = self.Serializer(data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == {'write_only': 123, 'readable': 456} + +    def test_serialize_write_only(self): +        """ +        Write-only serializers.should not be serialized. +        """ +        instance = {'write_only': 123, 'readable': 456} +        serializer = self.Serializer(instance) +        assert serializer.data == {'readable': 456} + + +class TestInitial: +    def setup(self): +        class TestSerializer(serializers.Serializer): +            initial_field = serializers.IntegerField(initial=123) +            blank_field = serializers.IntegerField() +        self.serializer = TestSerializer() + +    def test_initial(self): +        """ +        Initial values should be included when serializing a new representation. +        """ +        assert self.serializer.data == { +            'initial_field': 123, +            'blank_field': None +        } + + +class TestLabel: +    def setup(self): +        class TestSerializer(serializers.Serializer): +            labeled = serializers.IntegerField(label='My label') +        self.serializer = TestSerializer() + +    def test_label(self): +        """ +        A field's label may be set with the `label` argument. +        """ +        fields = self.serializer.fields +        assert fields['labeled'].label == 'My label' + + +class TestInvalidErrorKey: +    def setup(self): +        class ExampleField(serializers.Field): +            def to_native(self, data): +                self.fail('incorrect') +        self.field = ExampleField() + +    def test_invalid_error_key(self): +        """ +        If a field raises a validation error, but does not have a corresponding +        error message, then raise an appropriate assertion error. +        """ +        with pytest.raises(AssertionError) as exc_info: +            self.field.to_native(123) +        expected = ( +            'ValidationError raised by `ExampleField`, but error key ' +            '`incorrect` does not exist in the `error_messages` dictionary.' +        ) +        assert str(exc_info.value) == expected + + +class TestBooleanHTMLInput: +    def setup(self): +        class TestSerializer(serializers.Serializer): +            archived = serializers.BooleanField() +        self.Serializer = TestSerializer + +    def test_empty_html_checkbox(self): +        """ +        HTML checkboxes do not send any value, but should be treated +        as `False` by BooleanField. +        """ +        # This class mocks up a dictionary like object, that behaves +        # as if it was returned for multipart or urlencoded data. +        class MockHTMLDict(dict): +            getlist = None +        serializer = self.Serializer(data=MockHTMLDict()) +        assert serializer.is_valid() +        assert serializer.validated_data == {'archived': False} + + +class MockHTMLDict(dict): +    """ +    This class mocks up a dictionary like object, that behaves +    as if it was returned for multipart or urlencoded data. +    """ +    getlist = None + + +class TestHTMLInput: +    def test_empty_html_charfield(self): +        class TestSerializer(serializers.Serializer): +            message = serializers.CharField(default='happy') + +        serializer = TestSerializer(data=MockHTMLDict()) +        assert serializer.is_valid() +        assert serializer.validated_data == {'message': 'happy'} + +    def test_empty_html_charfield_allow_null(self): +        class TestSerializer(serializers.Serializer): +            message = serializers.CharField(allow_null=True) + +        serializer = TestSerializer(data=MockHTMLDict({'message': ''})) +        assert serializer.is_valid() +        assert serializer.validated_data == {'message': None} + +    def test_empty_html_datefield_allow_null(self): +        class TestSerializer(serializers.Serializer): +            expiry = serializers.DateField(allow_null=True) + +        serializer = TestSerializer(data=MockHTMLDict({'expiry': ''})) +        assert serializer.is_valid() +        assert serializer.validated_data == {'expiry': None} + +    def test_empty_html_charfield_allow_null_allow_blank(self): +        class TestSerializer(serializers.Serializer): +            message = serializers.CharField(allow_null=True, allow_blank=True) + +        serializer = TestSerializer(data=MockHTMLDict({'message': ''})) +        assert serializer.is_valid() +        assert serializer.validated_data == {'message': ''} + +    def test_empty_html_charfield_required_false(self): +        class TestSerializer(serializers.Serializer): +            message = serializers.CharField(required=False) + +        serializer = TestSerializer(data=MockHTMLDict()) +        assert serializer.is_valid() +        assert serializer.validated_data == {} + + +class TestCreateOnlyDefault: +    def setup(self): +        default = serializers.CreateOnlyDefault('2001-01-01') + +        class TestSerializer(serializers.Serializer): +            published = serializers.HiddenField(default=default) +            text = serializers.CharField() +        self.Serializer = TestSerializer + +    def test_create_only_default_is_provided(self): +        serializer = self.Serializer(data={'text': 'example'}) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'text': 'example', 'published': '2001-01-01' +        } + +    def test_create_only_default_is_not_provided_on_update(self): +        instance = { +            'text': 'example', 'published': '2001-01-01' +        } +        serializer = self.Serializer(instance, data={'text': 'example'}) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'text': 'example', +        } + +    def test_create_only_default_callable_sets_context(self): +        """ +        CreateOnlyDefault instances with a callable default should set_context +        on the callable if possible +        """ +        class TestCallableDefault: +            def set_context(self, serializer_field): +                self.field = serializer_field + +            def __call__(self): +                return "success" if hasattr(self, 'field') else "failure" + +        class TestSerializer(serializers.Serializer): +            context_set = serializers.CharField(default=serializers.CreateOnlyDefault(TestCallableDefault())) + +        serializer = TestSerializer(data={}) +        assert serializer.is_valid() +        assert serializer.validated_data['context_set'] == 'success' + + +# Tests for field input and output values. +# ---------------------------------------- + +def get_items(mapping_or_list_of_two_tuples): +    # Tests accept either lists of two tuples, or dictionaries. +    if isinstance(mapping_or_list_of_two_tuples, dict): +        # {value: expected} +        return mapping_or_list_of_two_tuples.items() +    # [(value, expected), ...] +    return mapping_or_list_of_two_tuples + + +class FieldValues: +    """ +    Base class for testing valid and invalid input values. +    """ +    def test_valid_inputs(self): +        """ +        Ensure that valid values return the expected validated data. +        """ +        for input_value, expected_output in get_items(self.valid_inputs): +            assert self.field.run_validation(input_value) == expected_output + +    def test_invalid_inputs(self): +        """ +        Ensure that invalid values raise the expected validation error. +        """ +        for input_value, expected_failure in get_items(self.invalid_inputs): +            with pytest.raises(serializers.ValidationError) as exc_info: +                self.field.run_validation(input_value) +            assert exc_info.value.detail == expected_failure + +    def test_outputs(self): +        for output_value, expected_output in get_items(self.outputs): +            assert self.field.to_representation(output_value) == expected_output + + +# Boolean types... + +class TestBooleanField(FieldValues): +    """ +    Valid and invalid values for `BooleanField`. +    """ +    valid_inputs = { +        'true': True, +        'false': False, +        '1': True, +        '0': False, +        1: True, +        0: False, +        True: True, +        False: False, +    } +    invalid_inputs = { +        'foo': ['"foo" is not a valid boolean.'], +        None: ['This field may not be null.'] +    } +    outputs = { +        'true': True, +        'false': False, +        '1': True, +        '0': False, +        1: True, +        0: False, +        True: True, +        False: False, +        'other': True +    } +    field = serializers.BooleanField() + + +class TestNullBooleanField(FieldValues): +    """ +    Valid and invalid values for `BooleanField`. +    """ +    valid_inputs = { +        'true': True, +        'false': False, +        'null': None, +        True: True, +        False: False, +        None: None +    } +    invalid_inputs = { +        'foo': ['"foo" is not a valid boolean.'], +    } +    outputs = { +        'true': True, +        'false': False, +        'null': None, +        True: True, +        False: False, +        None: None, +        'other': True +    } +    field = serializers.NullBooleanField() + + +# String types... + +class TestCharField(FieldValues): +    """ +    Valid and invalid values for `CharField`. +    """ +    valid_inputs = { +        1: '1', +        'abc': 'abc' +    } +    invalid_inputs = { +        '': ['This field may not be blank.'] +    } +    outputs = { +        1: '1', +        'abc': 'abc' +    } +    field = serializers.CharField() + +    def test_trim_whitespace_default(self): +        field = serializers.CharField() +        assert field.to_internal_value(' abc ') == 'abc' + +    def test_trim_whitespace_disabled(self): +        field = serializers.CharField(trim_whitespace=False) +        assert field.to_internal_value(' abc ') == ' abc ' + + +class TestEmailField(FieldValues): +    """ +    Valid and invalid values for `EmailField`. +    """ +    valid_inputs = { +        'example@example.com': 'example@example.com', +        ' example@example.com ': 'example@example.com', +    } +    invalid_inputs = { +        'examplecom': ['Enter a valid email address.'] +    } +    outputs = {} +    field = serializers.EmailField() + + +class TestRegexField(FieldValues): +    """ +    Valid and invalid values for `RegexField`. +    """ +    valid_inputs = { +        'a9': 'a9', +    } +    invalid_inputs = { +        'A9': ["This value does not match the required pattern."] +    } +    outputs = {} +    field = serializers.RegexField(regex='[a-z][0-9]') + + +class TestSlugField(FieldValues): +    """ +    Valid and invalid values for `SlugField`. +    """ +    valid_inputs = { +        'slug-99': 'slug-99', +    } +    invalid_inputs = { +        'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.'] +    } +    outputs = {} +    field = serializers.SlugField() + + +class TestURLField(FieldValues): +    """ +    Valid and invalid values for `URLField`. +    """ +    valid_inputs = { +        'http://example.com': 'http://example.com', +    } +    invalid_inputs = { +        'example.com': ['Enter a valid URL.'] +    } +    outputs = {} +    field = serializers.URLField() + + +class TestUUIDField(FieldValues): +    """ +    Valid and invalid values for `UUIDField`. +    """ +    valid_inputs = { +        '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'), +        '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda') +    } +    invalid_inputs = { +        '825d7aeb-05a9-45b5-a5b7': ['"825d7aeb-05a9-45b5-a5b7" is not a valid UUID.'] +    } +    outputs = { +        uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'): '825d7aeb-05a9-45b5-a5b7-05df87923cda' +    } +    field = serializers.UUIDField() + + +# Number types... + +class TestIntegerField(FieldValues): +    """ +    Valid and invalid values for `IntegerField`. +    """ +    valid_inputs = { +        '1': 1, +        '0': 0, +        1: 1, +        0: 0, +        1.0: 1, +        0.0: 0 +    } +    invalid_inputs = { +        'abc': ['A valid integer is required.'] +    } +    outputs = { +        '1': 1, +        '0': 0, +        1: 1, +        0: 0, +        1.0: 1, +        0.0: 0 +    } +    field = serializers.IntegerField() + + +class TestMinMaxIntegerField(FieldValues): +    """ +    Valid and invalid values for `IntegerField` with min and max limits. +    """ +    valid_inputs = { +        '1': 1, +        '3': 3, +        1: 1, +        3: 3, +    } +    invalid_inputs = { +        0: ['Ensure this value is greater than or equal to 1.'], +        4: ['Ensure this value is less than or equal to 3.'], +        '0': ['Ensure this value is greater than or equal to 1.'], +        '4': ['Ensure this value is less than or equal to 3.'], +    } +    outputs = {} +    field = serializers.IntegerField(min_value=1, max_value=3) + + +class TestFloatField(FieldValues): +    """ +    Valid and invalid values for `FloatField`. +    """ +    valid_inputs = { +        '1': 1.0, +        '0': 0.0, +        1: 1.0, +        0: 0.0, +        1.0: 1.0, +        0.0: 0.0, +    } +    invalid_inputs = { +        'abc': ["A valid number is required."] +    } +    outputs = { +        '1': 1.0, +        '0': 0.0, +        1: 1.0, +        0: 0.0, +        1.0: 1.0, +        0.0: 0.0, +    } +    field = serializers.FloatField() + + +class TestMinMaxFloatField(FieldValues): +    """ +    Valid and invalid values for `FloatField` with min and max limits. +    """ +    valid_inputs = { +        '1': 1, +        '3': 3, +        1: 1, +        3: 3, +        1.0: 1.0, +        3.0: 3.0, +    } +    invalid_inputs = { +        0.9: ['Ensure this value is greater than or equal to 1.'], +        3.1: ['Ensure this value is less than or equal to 3.'], +        '0.0': ['Ensure this value is greater than or equal to 1.'], +        '3.1': ['Ensure this value is less than or equal to 3.'], +    } +    outputs = {} +    field = serializers.FloatField(min_value=1, max_value=3) + + +class TestDecimalField(FieldValues): +    """ +    Valid and invalid values for `DecimalField`. +    """ +    valid_inputs = { +        '12.3': Decimal('12.3'), +        '0.1': Decimal('0.1'), +        10: Decimal('10'), +        0: Decimal('0'), +        12.3: Decimal('12.3'), +        0.1: Decimal('0.1'), +    } +    invalid_inputs = ( +        ('abc', ["A valid number is required."]), +        (Decimal('Nan'), ["A valid number is required."]), +        (Decimal('Inf'), ["A valid number is required."]), +        ('12.345', ["Ensure that there are no more than 3 digits in total."]), +        ('0.01', ["Ensure that there are no more than 1 decimal places."]), +        (123, ["Ensure that there are no more than 2 digits before the decimal point."]) +    ) +    outputs = { +        '1': '1.0', +        '0': '0.0', +        '1.09': '1.1', +        '0.04': '0.0', +        1: '1.0', +        0: '0.0', +        Decimal('1.0'): '1.0', +        Decimal('0.0'): '0.0', +        Decimal('1.09'): '1.1', +        Decimal('0.04'): '0.0' +    } +    field = serializers.DecimalField(max_digits=3, decimal_places=1) + + +class TestMinMaxDecimalField(FieldValues): +    """ +    Valid and invalid values for `DecimalField` with min and max limits. +    """ +    valid_inputs = { +        '10.0': Decimal('10.0'), +        '20.0': Decimal('20.0'), +    } +    invalid_inputs = { +        '9.9': ['Ensure this value is greater than or equal to 10.'], +        '20.1': ['Ensure this value is less than or equal to 20.'], +    } +    outputs = {} +    field = serializers.DecimalField( +        max_digits=3, decimal_places=1, +        min_value=10, max_value=20 +    ) + + +class TestNoStringCoercionDecimalField(FieldValues): +    """ +    Output values for `DecimalField` with `coerce_to_string=False`. +    """ +    valid_inputs = {} +    invalid_inputs = {} +    outputs = { +        1.09: Decimal('1.1'), +        0.04: Decimal('0.0'), +        '1.09': Decimal('1.1'), +        '0.04': Decimal('0.0'), +        Decimal('1.09'): Decimal('1.1'), +        Decimal('0.04'): Decimal('0.0'), +    } +    field = serializers.DecimalField( +        max_digits=3, decimal_places=1, +        coerce_to_string=False +    ) + + +# Date & time serializers... + +class TestDateField(FieldValues): +    """ +    Valid and invalid values for `DateField`. +    """ +    valid_inputs = { +        '2001-01-01': datetime.date(2001, 1, 1), +        datetime.date(2001, 1, 1): datetime.date(2001, 1, 1), +    } +    invalid_inputs = { +        'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], +        '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], +        datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'], +    } +    outputs = { +        datetime.date(2001, 1, 1): '2001-01-01' +    } +    field = serializers.DateField() + + +class TestCustomInputFormatDateField(FieldValues): +    """ +    Valid and invalid values for `DateField` with a cutom input format. +    """ +    valid_inputs = { +        '1 Jan 2001': datetime.date(2001, 1, 1), +    } +    invalid_inputs = { +        '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.'] +    } +    outputs = {} +    field = serializers.DateField(input_formats=['%d %b %Y']) + + +class TestCustomOutputFormatDateField(FieldValues): +    """ +    Values for `DateField` with a custom output format. +    """ +    valid_inputs = {} +    invalid_inputs = {} +    outputs = { +        datetime.date(2001, 1, 1): '01 Jan 2001' +    } +    field = serializers.DateField(format='%d %b %Y') + + +class TestNoOutputFormatDateField(FieldValues): +    """ +    Values for `DateField` with no output format. +    """ +    valid_inputs = {} +    invalid_inputs = {} +    outputs = { +        datetime.date(2001, 1, 1): datetime.date(2001, 1, 1) +    } +    field = serializers.DateField(format=None) + + +class TestDateTimeField(FieldValues): +    """ +    Valid and invalid values for `DateTimeField`. +    """ +    valid_inputs = { +        '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), +        '2001-01-01T13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), +        '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), +        datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), +        datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), +        # Django 1.4 does not support timezone string parsing. +        '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()) +    } +    invalid_inputs = { +        'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], +        '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], +        datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'], +    } +    outputs = { +        datetime.datetime(2001, 1, 1, 13, 00): '2001-01-01T13:00:00', +        datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): '2001-01-01T13:00:00Z' +    } +    field = serializers.DateTimeField(default_timezone=timezone.UTC()) + + +class TestCustomInputFormatDateTimeField(FieldValues): +    """ +    Valid and invalid values for `DateTimeField` with a cutom input format. +    """ +    valid_inputs = { +        '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()), +    } +    invalid_inputs = { +        '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.'] +    } +    outputs = {} +    field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y']) + + +class TestCustomOutputFormatDateTimeField(FieldValues): +    """ +    Values for `DateTimeField` with a custom output format. +    """ +    valid_inputs = {} +    invalid_inputs = {} +    outputs = { +        datetime.datetime(2001, 1, 1, 13, 00): '01:00PM, 01 Jan 2001', +    } +    field = serializers.DateTimeField(format='%I:%M%p, %d %b %Y') + + +class TestNoOutputFormatDateTimeField(FieldValues): +    """ +    Values for `DateTimeField` with no output format. +    """ +    valid_inputs = {} +    invalid_inputs = {} +    outputs = { +        datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00), +    } +    field = serializers.DateTimeField(format=None) + + +class TestNaiveDateTimeField(FieldValues): +    """ +    Valid and invalid values for `DateTimeField` with naive datetimes. +    """ +    valid_inputs = { +        datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00), +        '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00), +    } +    invalid_inputs = {} +    outputs = {} +    field = serializers.DateTimeField(default_timezone=None) + + +class TestTimeField(FieldValues): +    """ +    Valid and invalid values for `TimeField`. +    """ +    valid_inputs = { +        '13:00': datetime.time(13, 00), +        datetime.time(13, 00): datetime.time(13, 00), +    } +    invalid_inputs = { +        'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], +        '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], +    } +    outputs = { +        datetime.time(13, 00): '13:00:00' +    } +    field = serializers.TimeField() + + +class TestCustomInputFormatTimeField(FieldValues): +    """ +    Valid and invalid values for `TimeField` with a custom input format. +    """ +    valid_inputs = { +        '1:00pm': datetime.time(13, 00), +    } +    invalid_inputs = { +        '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'], +    } +    outputs = {} +    field = serializers.TimeField(input_formats=['%I:%M%p']) + + +class TestCustomOutputFormatTimeField(FieldValues): +    """ +    Values for `TimeField` with a custom output format. +    """ +    valid_inputs = {} +    invalid_inputs = {} +    outputs = { +        datetime.time(13, 00): '01:00PM' +    } +    field = serializers.TimeField(format='%I:%M%p') + + +class TestNoOutputFormatTimeField(FieldValues): +    """ +    Values for `TimeField` with a no output format. +    """ +    valid_inputs = {} +    invalid_inputs = {} +    outputs = { +        datetime.time(13, 00): datetime.time(13, 00) +    } +    field = serializers.TimeField(format=None) + + +# Choice types... + +class TestChoiceField(FieldValues): +    """ +    Valid and invalid values for `ChoiceField`. +    """ +    valid_inputs = { +        'poor': 'poor', +        'medium': 'medium', +        'good': 'good', +    } +    invalid_inputs = { +        'amazing': ['"amazing" is not a valid choice.'] +    } +    outputs = { +        'good': 'good', +        '': '' +    } +    field = serializers.ChoiceField( +        choices=[ +            ('poor', 'Poor quality'), +            ('medium', 'Medium quality'), +            ('good', 'Good quality'), +        ] +    ) + +    def test_allow_blank(self): +        """ +        If `allow_blank=True` then '' is a valid input. +        """ +        field = serializers.ChoiceField( +            allow_blank=True, +            choices=[ +                ('poor', 'Poor quality'), +                ('medium', 'Medium quality'), +                ('good', 'Good quality'), +            ] +        ) +        output = field.run_validation('') +        assert output == '' + + +class TestChoiceFieldWithType(FieldValues): +    """ +    Valid and invalid values for a `Choice` field that uses an integer type, +    instead of a char type. +    """ +    valid_inputs = { +        '1': 1, +        3: 3, +    } +    invalid_inputs = { +        5: ['"5" is not a valid choice.'], +        'abc': ['"abc" is not a valid choice.'] +    } +    outputs = { +        '1': 1, +        1: 1 +    } +    field = serializers.ChoiceField( +        choices=[ +            (1, 'Poor quality'), +            (2, 'Medium quality'), +            (3, 'Good quality'), +        ] +    ) + + +class TestChoiceFieldWithListChoices(FieldValues): +    """ +    Valid and invalid values for a `Choice` field that uses a flat list for the +    choices, rather than a list of pairs of (`value`, `description`). +    """ +    valid_inputs = { +        'poor': 'poor', +        'medium': 'medium', +        'good': 'good', +    } +    invalid_inputs = { +        'awful': ['"awful" is not a valid choice.'] +    } +    outputs = { +        'good': 'good' +    } +    field = serializers.ChoiceField(choices=('poor', 'medium', 'good')) + + +class TestMultipleChoiceField(FieldValues): +    """ +    Valid and invalid values for `MultipleChoiceField`. +    """ +    valid_inputs = { +        (): set(), +        ('aircon',): set(['aircon']), +        ('aircon', 'manual'): set(['aircon', 'manual']), +    } +    invalid_inputs = { +        'abc': ['Expected a list of items but got type "str".'], +        ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.'] +    } +    outputs = [ +        (['aircon', 'manual'], set(['aircon', 'manual'])) +    ] +    field = serializers.MultipleChoiceField( +        choices=[ +            ('aircon', 'AirCon'), +            ('manual', 'Manual drive'), +            ('diesel', 'Diesel'), +        ] +    ) + + +# File serializers... + +class MockFile: +    def __init__(self, name='', size=0, url=''): +        self.name = name +        self.size = size +        self.url = url + +    def __eq__(self, other): +        return ( +            isinstance(other, MockFile) and +            self.name == other.name and +            self.size == other.size and +            self.url == other.url +        ) + + +class TestFileField(FieldValues): +    """ +    Values for `FileField`. +    """ +    valid_inputs = [ +        (MockFile(name='example', size=10), MockFile(name='example', size=10)) +    ] +    invalid_inputs = [ +        ('invalid', ['The submitted data was not a file. Check the encoding type on the form.']), +        (MockFile(name='example.txt', size=0), ['The submitted file is empty.']), +        (MockFile(name='', size=10), ['No filename could be determined.']), +        (MockFile(name='x' * 100, size=10), ['Ensure this filename has at most 10 characters (it has 100).']) +    ] +    outputs = [ +        (MockFile(name='example.txt', url='/example.txt'), '/example.txt'), +        ('', None) +    ] +    field = serializers.FileField(max_length=10) + + +class TestFieldFieldWithName(FieldValues): +    """ +    Values for `FileField` with a filename output instead of URLs. +    """ +    valid_inputs = {} +    invalid_inputs = {} +    outputs = [ +        (MockFile(name='example.txt', url='/example.txt'), 'example.txt') +    ] +    field = serializers.FileField(use_url=False) + + +# Stub out mock Django `forms.ImageField` class so we don't *actually* +# call into it's regular validation, or require PIL for testing. +class FailImageValidation(object): +    def to_python(self, value): +        raise serializers.ValidationError(self.error_messages['invalid_image']) + + +class PassImageValidation(object): +    def to_python(self, value): +        return value + + +class TestInvalidImageField(FieldValues): +    """ +    Values for an invalid `ImageField`. +    """ +    valid_inputs = {} +    invalid_inputs = [ +        (MockFile(name='example.txt', size=10), ['Upload a valid image. The file you uploaded was either not an image or a corrupted image.']) +    ] +    outputs = {} +    field = serializers.ImageField(_DjangoImageField=FailImageValidation) + + +class TestValidImageField(FieldValues): +    """ +    Values for an valid `ImageField`. +    """ +    valid_inputs = [ +        (MockFile(name='example.txt', size=10), MockFile(name='example.txt', size=10)) +    ] +    invalid_inputs = {} +    outputs = {} +    field = serializers.ImageField(_DjangoImageField=PassImageValidation) + + +# Composite serializers... + +class TestListField(FieldValues): +    """ +    Values for `ListField` with IntegerField as child. +    """ +    valid_inputs = [ +        ([1, 2, 3], [1, 2, 3]), +        (['1', '2', '3'], [1, 2, 3]) +    ] +    invalid_inputs = [ +        ('not a list', ['Expected a list of items but got type "str".']), +        ([1, 2, 'error'], ['A valid integer is required.']) +    ] +    outputs = [ +        ([1, 2, 3], [1, 2, 3]), +        (['1', '2', '3'], [1, 2, 3]) +    ] +    field = serializers.ListField(child=serializers.IntegerField()) + + +class TestUnvalidatedListField(FieldValues): +    """ +    Values for `ListField` with no `child` argument. +    """ +    valid_inputs = [ +        ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), +    ] +    invalid_inputs = [ +        ('not a list', ['Expected a list of items but got type "str".']), +    ] +    outputs = [ +        ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), +    ] +    field = serializers.ListField() + + +class TestDictField(FieldValues): +    """ +    Values for `ListField` with CharField as child. +    """ +    valid_inputs = [ +        ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), +    ] +    invalid_inputs = [ +        ({'a': 1, 'b': None}, ['This field may not be null.']), +        ('not a dict', ['Expected a dictionary of items but got type "str".']), +    ] +    outputs = [ +        ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), +    ] +    field = serializers.DictField(child=serializers.CharField()) + + +class TestUnvalidatedDictField(FieldValues): +    """ +    Values for `ListField` with no `child` argument. +    """ +    valid_inputs = [ +        ({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}), +    ] +    invalid_inputs = [ +        ('not a dict', ['Expected a dictionary of items but got type "str".']), +    ] +    outputs = [ +        ({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}), +    ] +    field = serializers.DictField() + + +# Tests for FieldField. +# --------------------- + +class MockRequest: +    def build_absolute_uri(self, value): +        return 'http://example.com' + value + + +class TestFileFieldContext: +    def test_fully_qualified_when_request_in_context(self): +        field = serializers.FileField(max_length=10) +        field._context = {'request': MockRequest()} +        obj = MockFile(name='example.txt', url='/example.txt') +        value = field.to_representation(obj) +        assert value == 'http://example.com/example.txt' + + +# Tests for SerializerMethodField. +# -------------------------------- + +class TestSerializerMethodField: +    def test_serializer_method_field(self): +        class ExampleSerializer(serializers.Serializer): +            example_field = serializers.SerializerMethodField() + +            def get_example_field(self, obj): +                return 'ran get_example_field(%d)' % obj['example_field'] + +        serializer = ExampleSerializer({'example_field': 123}) +        assert serializer.data == { +            'example_field': 'ran get_example_field(123)' +        } + +    def test_redundant_method_name(self): +        class ExampleSerializer(serializers.Serializer): +            example_field = serializers.SerializerMethodField('get_example_field') + +        with pytest.raises(AssertionError) as exc_info: +            ExampleSerializer().fields +        assert str(exc_info.value) == ( +            "It is redundant to specify `get_example_field` on " +            "SerializerMethodField 'example_field' in serializer " +            "'ExampleSerializer', because it is the same as the default " +            "method name. Remove the `method_name` argument." +        ) diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 00000000..e7cb0c79 --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,823 @@ +from __future__ import unicode_literals +import datetime +from decimal import Decimal +from django.db import models +from django.conf.urls import patterns, url +from django.core.urlresolvers import reverse +from django.test import TestCase +from django.test.utils import override_settings +from django.utils import unittest +from django.utils.dateparse import parse_date +from django.utils.six.moves import reload_module +from rest_framework import generics, serializers, status, filters +from rest_framework.compat import django_filters +from rest_framework.test import APIRequestFactory +from .models import BaseFilterableItem, FilterableItem, BasicModel + + +factory = APIRequestFactory() + + +if django_filters: +    class FilterableItemSerializer(serializers.ModelSerializer): +        class Meta: +            model = FilterableItem + +    # Basic filter on a list view. +    class FilterFieldsRootView(generics.ListCreateAPIView): +        queryset = FilterableItem.objects.all() +        serializer_class = FilterableItemSerializer +        filter_fields = ['decimal', 'date'] +        filter_backends = (filters.DjangoFilterBackend,) + +    # These class are used to test a filter class. +    class SeveralFieldsFilter(django_filters.FilterSet): +        text = django_filters.CharFilter(lookup_type='icontains') +        decimal = django_filters.NumberFilter(lookup_type='lt') +        date = django_filters.DateFilter(lookup_type='gt') + +        class Meta: +            model = FilterableItem +            fields = ['text', 'decimal', 'date'] + +    class FilterClassRootView(generics.ListCreateAPIView): +        queryset = FilterableItem.objects.all() +        serializer_class = FilterableItemSerializer +        filter_class = SeveralFieldsFilter +        filter_backends = (filters.DjangoFilterBackend,) + +    # These classes are used to test a misconfigured filter class. +    class MisconfiguredFilter(django_filters.FilterSet): +        text = django_filters.CharFilter(lookup_type='icontains') + +        class Meta: +            model = BasicModel +            fields = ['text'] + +    class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): +        queryset = FilterableItem.objects.all() +        serializer_class = FilterableItemSerializer +        filter_class = MisconfiguredFilter +        filter_backends = (filters.DjangoFilterBackend,) + +    class FilterClassDetailView(generics.RetrieveAPIView): +        queryset = FilterableItem.objects.all() +        serializer_class = FilterableItemSerializer +        filter_class = SeveralFieldsFilter +        filter_backends = (filters.DjangoFilterBackend,) + +    # These classes are used to test base model filter support +    class BaseFilterableItemFilter(django_filters.FilterSet): +        text = django_filters.CharFilter() + +        class Meta: +            model = BaseFilterableItem + +    class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): +        queryset = FilterableItem.objects.all() +        serializer_class = FilterableItemSerializer +        filter_class = BaseFilterableItemFilter +        filter_backends = (filters.DjangoFilterBackend,) + +    # Regression test for #814 +    class FilterFieldsQuerysetView(generics.ListCreateAPIView): +        queryset = FilterableItem.objects.all() +        serializer_class = FilterableItemSerializer +        filter_fields = ['decimal', 'date'] +        filter_backends = (filters.DjangoFilterBackend,) + +    class GetQuerysetView(generics.ListCreateAPIView): +        serializer_class = FilterableItemSerializer +        filter_class = SeveralFieldsFilter +        filter_backends = (filters.DjangoFilterBackend,) + +        def get_queryset(self): +            return FilterableItem.objects.all() + +    urlpatterns = patterns( +        '', +        url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), +        url(r'^$', FilterClassRootView.as_view(), name='root-view'), +        url(r'^get-queryset/$', GetQuerysetView.as_view(), +            name='get-queryset-view'), +    ) + + +class CommonFilteringTestCase(TestCase): +    def _serialize_object(self, obj): +        return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} + +    def setUp(self): +        """ +        Create 10 FilterableItem instances. +        """ +        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) +        for i in range(10): +            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. +            decimal = base_data[1] + i +            date = base_data[2] - datetime.timedelta(days=i * 2) +            FilterableItem(text=text, decimal=decimal, date=date).save() + +        self.objects = FilterableItem.objects +        self.data = [ +            self._serialize_object(obj) +            for obj in self.objects.all() +        ] + + +class IntegrationTestFiltering(CommonFilteringTestCase): +    """ +    Integration tests for filtered list views. +    """ + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_get_filtered_fields_root_view(self): +        """ +        GET requests to paginated ListCreateAPIView should return paginated results. +        """ +        view = FilterFieldsRootView.as_view() + +        # Basic test with no filter. +        request = factory.get('/') +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +        # Tests that the decimal filter works. +        search_decimal = Decimal('2.25') +        request = factory.get('/', {'decimal': '%s' % search_decimal}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal] +        self.assertEqual(response.data, expected_data) + +        # Tests that the date filter works. +        search_date = datetime.date(2012, 9, 22) +        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-09-22' +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if parse_date(f['date']) == search_date] +        self.assertEqual(response.data, expected_data) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_filter_with_queryset(self): +        """ +        Regression test for #814. +        """ +        view = FilterFieldsQuerysetView.as_view() + +        # Tests that the decimal filter works. +        search_decimal = Decimal('2.25') +        request = factory.get('/', {'decimal': '%s' % search_decimal}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal] +        self.assertEqual(response.data, expected_data) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_filter_with_get_queryset_only(self): +        """ +        Regression test for #834. +        """ +        view = GetQuerysetView.as_view() +        request = factory.get('/get-queryset/') +        view(request).render() +        # Used to raise "issubclass() arg 2 must be a class or tuple of classes" +        # here when neither `model' nor `queryset' was specified. + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_get_filtered_class_root_view(self): +        """ +        GET requests to filtered ListCreateAPIView that have a filter_class set +        should return filtered results. +        """ +        view = FilterClassRootView.as_view() + +        # Basic test with no filter. +        request = factory.get('/') +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +        # Tests that the decimal filter set with 'lt' in the filter class works. +        search_decimal = Decimal('4.25') +        request = factory.get('/', {'decimal': '%s' % search_decimal}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if Decimal(f['decimal']) < search_decimal] +        self.assertEqual(response.data, expected_data) + +        # Tests that the date filter set with 'gt' in the filter class works. +        search_date = datetime.date(2012, 10, 2) +        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-10-02' +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if parse_date(f['date']) > search_date] +        self.assertEqual(response.data, expected_data) + +        # Tests that the text filter set with 'icontains' in the filter class works. +        search_text = 'ff' +        request = factory.get('/', {'text': '%s' % search_text}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if search_text in f['text'].lower()] +        self.assertEqual(response.data, expected_data) + +        # Tests that multiple filters works. +        search_decimal = Decimal('5.25') +        search_date = datetime.date(2012, 10, 2) +        request = factory.get('/', { +            'decimal': '%s' % (search_decimal,), +            'date': '%s' % (search_date,) +        }) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if parse_date(f['date']) > search_date and +                         Decimal(f['decimal']) < search_decimal] +        self.assertEqual(response.data, expected_data) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_incorrectly_configured_filter(self): +        """ +        An error should be displayed when the filter class is misconfigured. +        """ +        view = IncorrectlyConfiguredRootView.as_view() + +        request = factory.get('/') +        self.assertRaises(AssertionError, view, request) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_base_model_filter(self): +        """ +        The `get_filter_class` model checks should allow base model filters. +        """ +        view = BaseFilterableItemFilterRootView.as_view() + +        request = factory.get('/?text=aaa') +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(len(response.data), 1) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_unknown_filter(self): +        """ +        GET requests with filters that aren't configured should return 200. +        """ +        view = FilterFieldsRootView.as_view() + +        search_integer = 10 +        request = factory.get('/', {'integer': '%s' % search_integer}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): +    """ +    Integration tests for filtered detail views. +    """ +    urls = 'tests.test_filters' + +    def _get_url(self, item): +        return reverse('detail-view', kwargs=dict(pk=item.pk)) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_get_filtered_detail_view(self): +        """ +        GET requests to filtered RetrieveAPIView that have a filter_class set +        should return filtered results. +        """ +        item = self.objects.all()[0] +        data = self._serialize_object(item) + +        # Basic test with no filter. +        response = self.client.get(self._get_url(item)) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, data) + +        # Tests that the decimal filter set that should fail. +        search_decimal = Decimal('4.25') +        high_item = self.objects.filter(decimal__gt=search_decimal)[0] +        response = self.client.get( +            '{url}'.format(url=self._get_url(high_item)), +            {'decimal': '{param}'.format(param=search_decimal)}) +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +        # Tests that the decimal filter set that should succeed. +        search_decimal = Decimal('4.25') +        low_item = self.objects.filter(decimal__lt=search_decimal)[0] +        low_item_data = self._serialize_object(low_item) +        response = self.client.get( +            '{url}'.format(url=self._get_url(low_item)), +            {'decimal': '{param}'.format(param=search_decimal)}) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, low_item_data) + +        # Tests that multiple filters works. +        search_decimal = Decimal('5.25') +        search_date = datetime.date(2012, 10, 2) +        valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] +        valid_item_data = self._serialize_object(valid_item) +        response = self.client.get( +            '{url}'.format(url=self._get_url(valid_item)), { +                'decimal': '{decimal}'.format(decimal=search_decimal), +                'date': '{date}'.format(date=search_date) +            }) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): +    title = models.CharField(max_length=20) +    text = models.CharField(max_length=100) + + +class SearchFilterSerializer(serializers.ModelSerializer): +    class Meta: +        model = SearchFilterModel + + +class SearchFilterTests(TestCase): +    def setUp(self): +        # Sequence of title/text is: +        # +        # z   abc +        # zz  bcd +        # zzz cde +        # ... +        for idx in range(10): +            title = 'z' * (idx + 1) +            text = ( +                chr(idx + ord('a')) + +                chr(idx + ord('b')) + +                chr(idx + ord('c')) +            ) +            SearchFilterModel(title=title, text=text).save() + +    def test_search(self): +        class SearchListView(generics.ListAPIView): +            queryset = SearchFilterModel.objects.all() +            serializer_class = SearchFilterSerializer +            filter_backends = (filters.SearchFilter,) +            search_fields = ('title', 'text') + +        view = SearchListView.as_view() +        request = factory.get('/', {'search': 'b'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 1, 'title': 'z', 'text': 'abc'}, +                {'id': 2, 'title': 'zz', 'text': 'bcd'} +            ] +        ) + +    def test_exact_search(self): +        class SearchListView(generics.ListAPIView): +            queryset = SearchFilterModel.objects.all() +            serializer_class = SearchFilterSerializer +            filter_backends = (filters.SearchFilter,) +            search_fields = ('=title', 'text') + +        view = SearchListView.as_view() +        request = factory.get('/', {'search': 'zzz'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'zzz', 'text': 'cde'} +            ] +        ) + +    def test_startswith_search(self): +        class SearchListView(generics.ListAPIView): +            queryset = SearchFilterModel.objects.all() +            serializer_class = SearchFilterSerializer +            filter_backends = (filters.SearchFilter,) +            search_fields = ('title', '^text') + +        view = SearchListView.as_view() +        request = factory.get('/', {'search': 'b'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 2, 'title': 'zz', 'text': 'bcd'} +            ] +        ) + +    def test_search_with_nonstandard_search_param(self): +        with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): +            reload_module(filters) + +            class SearchListView(generics.ListAPIView): +                queryset = SearchFilterModel.objects.all() +                serializer_class = SearchFilterSerializer +                filter_backends = (filters.SearchFilter,) +                search_fields = ('title', 'text') + +            view = SearchListView.as_view() +            request = factory.get('/', {'query': 'b'}) +            response = view(request) +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, 'title': 'z', 'text': 'abc'}, +                    {'id': 2, 'title': 'zz', 'text': 'bcd'} +                ] +            ) + +        reload_module(filters) + + +class AttributeModel(models.Model): +    label = models.CharField(max_length=32) + + +class SearchFilterModelM2M(models.Model): +    title = models.CharField(max_length=20) +    text = models.CharField(max_length=100) +    attributes = models.ManyToManyField(AttributeModel) + + +class SearchFilterM2MSerializer(serializers.ModelSerializer): +    class Meta: +        model = SearchFilterModelM2M + + +class SearchFilterM2MTests(TestCase): +    def setUp(self): +        # Sequence of title/text/attributes is: +        # +        # z   abc [1, 2, 3] +        # zz  bcd [1, 2, 3] +        # zzz cde [1, 2, 3] +        # ... +        for idx in range(3): +            label = 'w' * (idx + 1) +            AttributeModel(label=label) + +        for idx in range(10): +            title = 'z' * (idx + 1) +            text = ( +                chr(idx + ord('a')) + +                chr(idx + ord('b')) + +                chr(idx + ord('c')) +            ) +            SearchFilterModelM2M(title=title, text=text).save() +        SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3) + +    def test_m2m_search(self): +        class SearchListView(generics.ListAPIView): +            queryset = SearchFilterModelM2M.objects.all() +            serializer_class = SearchFilterM2MSerializer +            filter_backends = (filters.SearchFilter,) +            search_fields = ('=title', 'text', 'attributes__label') + +        view = SearchListView.as_view() +        request = factory.get('/', {'search': 'zz'}) +        response = view(request) +        self.assertEqual(len(response.data), 1) + + +class OrderingFilterModel(models.Model): +    title = models.CharField(max_length=20) +    text = models.CharField(max_length=100) + + +class OrderingFilterRelatedModel(models.Model): +    related_object = models.ForeignKey(OrderingFilterModel, +                                       related_name="relateds") + + +class OrderingFilterSerializer(serializers.ModelSerializer): +    class Meta: +        model = OrderingFilterModel + + +class DjangoFilterOrderingModel(models.Model): +    date = models.DateField() +    text = models.CharField(max_length=10) + +    class Meta: +        ordering = ['-date'] + + +class DjangoFilterOrderingSerializer(serializers.ModelSerializer): +    class Meta: +        model = DjangoFilterOrderingModel + + +class DjangoFilterOrderingTests(TestCase): +    def setUp(self): +        data = [{ +            'date': datetime.date(2012, 10, 8), +            'text': 'abc' +        }, { +            'date': datetime.date(2013, 10, 8), +            'text': 'bcd' +        }, { +            'date': datetime.date(2014, 10, 8), +            'text': 'cde' +        }] + +        for d in data: +            DjangoFilterOrderingModel.objects.create(**d) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_default_ordering(self): +        class DjangoFilterOrderingView(generics.ListAPIView): +            serializer_class = DjangoFilterOrderingSerializer +            queryset = DjangoFilterOrderingModel.objects.all() +            filter_backends = (filters.DjangoFilterBackend,) +            filter_fields = ['text'] +            ordering = ('-date',) + +        view = DjangoFilterOrderingView.as_view() +        request = factory.get('/') +        response = view(request) + +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'date': '2014-10-08', 'text': 'cde'}, +                {'id': 2, 'date': '2013-10-08', 'text': 'bcd'}, +                {'id': 1, 'date': '2012-10-08', 'text': 'abc'} +            ] +        ) + + +class OrderingFilterTests(TestCase): +    def setUp(self): +        # Sequence of title/text is: +        # +        # zyx abc +        # yxw bcd +        # xwv cde +        for idx in range(3): +            title = ( +                chr(ord('z') - idx) + +                chr(ord('y') - idx) + +                chr(ord('x') - idx) +            ) +            text = ( +                chr(idx + ord('a')) + +                chr(idx + ord('b')) + +                chr(idx + ord('c')) +            ) +            OrderingFilterModel(title=title, text=text).save() + +    def test_ordering(self): +        class OrderingListView(generics.ListAPIView): +            queryset = OrderingFilterModel.objects.all() +            serializer_class = OrderingFilterSerializer +            filter_backends = (filters.OrderingFilter,) +            ordering = ('title',) +            ordering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('/', {'ordering': 'text'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +            ] +        ) + +    def test_reverse_ordering(self): +        class OrderingListView(generics.ListAPIView): +            queryset = OrderingFilterModel.objects.all() +            serializer_class = OrderingFilterSerializer +            filter_backends = (filters.OrderingFilter,) +            ordering = ('title',) +            ordering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('/', {'ordering': '-text'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +            ] +        ) + +    def test_incorrectfield_ordering(self): +        class OrderingListView(generics.ListAPIView): +            queryset = OrderingFilterModel.objects.all() +            serializer_class = OrderingFilterSerializer +            filter_backends = (filters.OrderingFilter,) +            ordering = ('title',) +            ordering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('/', {'ordering': 'foobar'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +            ] +        ) + +    def test_default_ordering(self): +        class OrderingListView(generics.ListAPIView): +            queryset = OrderingFilterModel.objects.all() +            serializer_class = OrderingFilterSerializer +            filter_backends = (filters.OrderingFilter,) +            ordering = ('title',) +            oredering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('') +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +            ] +        ) + +    def test_default_ordering_using_string(self): +        class OrderingListView(generics.ListAPIView): +            queryset = OrderingFilterModel.objects.all() +            serializer_class = OrderingFilterSerializer +            filter_backends = (filters.OrderingFilter,) +            ordering = 'title' +            ordering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('') +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +            ] +        ) + +    def test_ordering_by_aggregate_field(self): +        # create some related models to aggregate order by +        num_objs = [2, 5, 3] +        for obj, num_relateds in zip(OrderingFilterModel.objects.all(), +                                     num_objs): +            for _ in range(num_relateds): +                new_related = OrderingFilterRelatedModel( +                    related_object=obj +                ) +                new_related.save() + +        class OrderingListView(generics.ListAPIView): +            serializer_class = OrderingFilterSerializer +            filter_backends = (filters.OrderingFilter,) +            ordering = 'title' +            ordering_fields = '__all__' +            queryset = OrderingFilterModel.objects.all().annotate( +                models.Count("relateds")) + +        view = OrderingListView.as_view() +        request = factory.get('/', {'ordering': 'relateds__count'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +            ] +        ) + +    def test_ordering_with_nonstandard_ordering_param(self): +        with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): +            reload_module(filters) + +            class OrderingListView(generics.ListAPIView): +                queryset = OrderingFilterModel.objects.all() +                serializer_class = OrderingFilterSerializer +                filter_backends = (filters.OrderingFilter,) +                ordering = ('title',) +                ordering_fields = ('text',) + +            view = OrderingListView.as_view() +            request = factory.get('/', {'order': 'text'}) +            response = view(request) +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                    {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                    {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                ] +            ) + +        reload_module(filters) + + +class SensitiveOrderingFilterModel(models.Model): +    username = models.CharField(max_length=20) +    password = models.CharField(max_length=100) + + +# Three different styles of serializer. +# All should allow ordering by username, but not by password. +class SensitiveDataSerializer1(serializers.ModelSerializer): +    username = serializers.CharField() + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'username') + + +class SensitiveDataSerializer2(serializers.ModelSerializer): +    username = serializers.CharField() +    password = serializers.CharField(write_only=True) + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'username', 'password') + + +class SensitiveDataSerializer3(serializers.ModelSerializer): +    user = serializers.CharField(source='username') + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'user') + + +class SensitiveOrderingFilterTests(TestCase): +    def setUp(self): +        for idx in range(3): +            username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] +            password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] +            SensitiveOrderingFilterModel(username=username, password=password).save() + +    def test_order_by_serializer_fields(self): +        for serializer_cls in [ +            SensitiveDataSerializer1, +            SensitiveDataSerializer2, +            SensitiveDataSerializer3 +        ]: +            class OrderingListView(generics.ListAPIView): +                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') +                filter_backends = (filters.OrderingFilter,) +                serializer_class = serializer_cls + +            view = OrderingListView.as_view() +            request = factory.get('/', {'ordering': '-username'}) +            response = view(request) + +            if serializer_cls == SensitiveDataSerializer3: +                username_field = 'user' +            else: +                username_field = 'username' + +            # Note: Inverse username ordering correctly applied. +            self.assertEqual( +                response.data, +                [ +                    {'id': 3, username_field: 'userC'}, +                    {'id': 2, username_field: 'userB'}, +                    {'id': 1, username_field: 'userA'}, +                ] +            ) + +    def test_cannot_order_by_non_serializer_fields(self): +        for serializer_cls in [ +            SensitiveDataSerializer1, +            SensitiveDataSerializer2, +            SensitiveDataSerializer3 +        ]: +            class OrderingListView(generics.ListAPIView): +                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') +                filter_backends = (filters.OrderingFilter,) +                serializer_class = serializer_cls + +            view = OrderingListView.as_view() +            request = factory.get('/', {'ordering': 'password'}) +            response = view(request) + +            if serializer_cls == SensitiveDataSerializer3: +                username_field = 'user' +            else: +                username_field = 'username' + +            # Note: The passwords are not in order.  Default ordering is used. +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, username_field: 'userA'},  # PassB +                    {'id': 2, username_field: 'userB'},  # PassC +                    {'id': 3, username_field: 'userC'},  # PassA +                ] +            ) diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 00000000..88e792ce --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,506 @@ +from __future__ import unicode_literals +import django +from django.db import models +from django.shortcuts import get_object_or_404 +from django.test import TestCase +from django.utils import six +from rest_framework import generics, renderers, serializers, status +from rest_framework.test import APIRequestFactory +from tests.models import BasicModel, RESTFrameworkModel +from tests.models import ForeignKeySource, ForeignKeyTarget + +factory = APIRequestFactory() + + +# Models +class SlugBasedModel(RESTFrameworkModel): +    text = models.CharField(max_length=100) +    slug = models.SlugField(max_length=32) + + +# Model for regression test for #285 +class Comment(RESTFrameworkModel): +    email = models.EmailField() +    content = models.CharField(max_length=200) +    created = models.DateTimeField(auto_now_add=True) + + +# Serializers +class BasicSerializer(serializers.ModelSerializer): +    class Meta: +        model = BasicModel + + +class ForeignKeySerializer(serializers.ModelSerializer): +    class Meta: +        model = ForeignKeySource + + +class SlugSerializer(serializers.ModelSerializer): +    slug = serializers.ReadOnlyField() + +    class Meta: +        model = SlugBasedModel +        fields = ('text', 'slug') + + +# Views +class RootView(generics.ListCreateAPIView): +    queryset = BasicModel.objects.all() +    serializer_class = BasicSerializer + + +class InstanceView(generics.RetrieveUpdateDestroyAPIView): +    queryset = BasicModel.objects.exclude(text='filtered out') +    serializer_class = BasicSerializer + + +class FKInstanceView(generics.RetrieveUpdateDestroyAPIView): +    queryset = ForeignKeySource.objects.all() +    serializer_class = ForeignKeySerializer + + +class SlugBasedInstanceView(InstanceView): +    """ +    A model with a slug-field. +    """ +    queryset = SlugBasedModel.objects.all() +    serializer_class = SlugSerializer +    lookup_field = 'slug' + + +# Tests +class TestRootView(TestCase): +    def setUp(self): +        """ +        Create 3 BasicModel instances. +        """ +        items = ['foo', 'bar', 'baz'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = RootView.as_view() + +    def test_get_root_view(self): +        """ +        GET requests to ListCreateAPIView should return list of objects. +        """ +        request = factory.get('/') +        with self.assertNumQueries(1): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +    def test_post_root_view(self): +        """ +        POST requests to ListCreateAPIView should create a new object. +        """ +        data = {'text': 'foobar'} +        request = factory.post('/', data, format='json') +        with self.assertNumQueries(1): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) +        created = self.objects.get(id=4) +        self.assertEqual(created.text, 'foobar') + +    def test_put_root_view(self): +        """ +        PUT requests to ListCreateAPIView should not be allowed +        """ +        data = {'text': 'foobar'} +        request = factory.put('/', data, format='json') +        with self.assertNumQueries(0): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) +        self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'}) + +    def test_delete_root_view(self): +        """ +        DELETE requests to ListCreateAPIView should not be allowed +        """ +        request = factory.delete('/') +        with self.assertNumQueries(0): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) +        self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'}) + +    def test_post_cannot_set_id(self): +        """ +        POST requests to create a new object should not be able to set the id. +        """ +        data = {'id': 999, 'text': 'foobar'} +        request = factory.post('/', data, format='json') +        with self.assertNumQueries(1): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) +        created = self.objects.get(id=4) +        self.assertEqual(created.text, 'foobar') + + +EXPECTED_QUERIES_FOR_PUT = 3 if django.VERSION < (1, 6) else 2 + + +class TestInstanceView(TestCase): +    def setUp(self): +        """ +        Create 3 BasicModel instances. +        """ +        items = ['foo', 'bar', 'baz', 'filtered out'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects.exclude(text='filtered out') +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = InstanceView.as_view() +        self.slug_based_view = SlugBasedInstanceView.as_view() + +    def test_get_instance_view(self): +        """ +        GET requests to RetrieveUpdateDestroyAPIView should return a single object. +        """ +        request = factory.get('/1') +        with self.assertNumQueries(1): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data[0]) + +    def test_post_instance_view(self): +        """ +        POST requests to RetrieveUpdateDestroyAPIView should not be allowed +        """ +        data = {'text': 'foobar'} +        request = factory.post('/', data, format='json') +        with self.assertNumQueries(0): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) +        self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'}) + +    def test_put_instance_view(self): +        """ +        PUT requests to RetrieveUpdateDestroyAPIView should update an object. +        """ +        data = {'text': 'foobar'} +        request = factory.put('/1', data, format='json') +        with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): +            response = self.view(request, pk='1').render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(dict(response.data), {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEqual(updated.text, 'foobar') + +    def test_patch_instance_view(self): +        """ +        PATCH requests to RetrieveUpdateDestroyAPIView should update an object. +        """ +        data = {'text': 'foobar'} +        request = factory.patch('/1', data, format='json') + +        with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEqual(updated.text, 'foobar') + +    def test_delete_instance_view(self): +        """ +        DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. +        """ +        request = factory.delete('/1') +        with self.assertNumQueries(2): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) +        self.assertEqual(response.content, six.b('')) +        ids = [obj.id for obj in self.objects.all()] +        self.assertEqual(ids, [2, 3]) + +    def test_get_instance_view_incorrect_arg(self): +        """ +        GET requests with an incorrect pk type, should raise 404, not 500. +        Regression test for #890. +        """ +        request = factory.get('/a') +        with self.assertNumQueries(0): +            response = self.view(request, pk='a').render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    def test_put_cannot_set_id(self): +        """ +        PUT requests to create a new object should not be able to set the id. +        """ +        data = {'id': 999, 'text': 'foobar'} +        request = factory.put('/1', data, format='json') +        with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEqual(updated.text, 'foobar') + +    def test_put_to_deleted_instance(self): +        """ +        PUT requests to RetrieveUpdateDestroyAPIView should return 404 if +        an object does not currently exist. +        """ +        self.objects.get(id=1).delete() +        data = {'text': 'foobar'} +        request = factory.put('/1', data, format='json') +        with self.assertNumQueries(1): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    def test_put_to_filtered_out_instance(self): +        """ +        PUT requests to an URL of instance which is filtered out should not be +        able to create new objects. +        """ +        data = {'text': 'foo'} +        filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk +        request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') +        response = self.view(request, pk=filtered_out_pk).render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    def test_patch_cannot_create_an_object(self): +        """ +        PATCH requests should not be able to create objects. +        """ +        data = {'text': 'foobar'} +        request = factory.patch('/999', data, format='json') +        with self.assertNumQueries(1): +            response = self.view(request, pk=999).render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertFalse(self.objects.filter(id=999).exists()) + + +class TestFKInstanceView(TestCase): +    def setUp(self): +        """ +        Create 3 BasicModel instances. +        """ +        items = ['foo', 'bar', 'baz'] +        for item in items: +            t = ForeignKeyTarget(name=item) +            t.save() +            ForeignKeySource(name='source_' + item, target=t).save() + +        self.objects = ForeignKeySource.objects +        self.data = [ +            {'id': obj.id, 'name': obj.name} +            for obj in self.objects.all() +        ] +        self.view = FKInstanceView.as_view() + + +class TestOverriddenGetObject(TestCase): +    """ +    Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the +    queryset/model mechanism but instead overrides get_object() +    """ + +    def setUp(self): +        """ +        Create 3 BasicModel instances. +        """ +        items = ['foo', 'bar', 'baz'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] + +        class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): +            """ +            Example detail view for override of get_object(). +            """ +            serializer_class = BasicSerializer + +            def get_object(self): +                pk = int(self.kwargs['pk']) +                return get_object_or_404(BasicModel.objects.all(), id=pk) + +        self.view = OverriddenGetObjectView.as_view() + +    def test_overridden_get_object_view(self): +        """ +        GET requests to RetrieveUpdateDestroyAPIView should return a single object. +        """ +        request = factory.get('/1') +        with self.assertNumQueries(1): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data[0]) + + +# Regression test for #285 + +class CommentSerializer(serializers.ModelSerializer): +    class Meta: +        model = Comment +        exclude = ('created',) + + +class CommentView(generics.ListCreateAPIView): +    serializer_class = CommentSerializer +    model = Comment + + +class TestCreateModelWithAutoNowAddField(TestCase): +    def setUp(self): +        self.objects = Comment.objects +        self.view = CommentView.as_view() + +    def test_create_model_with_auto_now_add_field(self): +        """ +        Regression test for #285 + +        https://github.com/tomchristie/django-rest-framework/issues/285 +        """ +        data = {'email': 'foobar@example.com', 'content': 'foobar'} +        request = factory.post('/', data, format='json') +        response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        created = self.objects.get(id=1) +        self.assertEqual(created.content, 'foobar') + + +# Test for particularly ugly regression with m2m in browsable API +class ClassB(models.Model): +    name = models.CharField(max_length=255) + + +class ClassA(models.Model): +    name = models.CharField(max_length=255) +    children = models.ManyToManyField(ClassB, blank=True, null=True) + + +class ClassASerializer(serializers.ModelSerializer): +    children = serializers.PrimaryKeyRelatedField( +        many=True, queryset=ClassB.objects.all() +    ) + +    class Meta: +        model = ClassA + + +class ExampleView(generics.ListCreateAPIView): +    serializer_class = ClassASerializer +    queryset = ClassA.objects.all() + + +class TestM2MBrowsableAPI(TestCase): +    def test_m2m_in_browsable_api(self): +        """ +        Test for particularly ugly regression with m2m in browsable API +        """ +        request = factory.get('/', HTTP_ACCEPT='text/html') +        view = ExampleView().as_view() +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class InclusiveFilterBackend(object): +    def filter_queryset(self, request, queryset, view): +        return queryset.filter(text='foo') + + +class ExclusiveFilterBackend(object): +    def filter_queryset(self, request, queryset, view): +        return queryset.filter(text='other') + + +class TwoFieldModel(models.Model): +    field_a = models.CharField(max_length=100) +    field_b = models.CharField(max_length=100) + + +class DynamicSerializerView(generics.ListCreateAPIView): +    queryset = TwoFieldModel.objects.all() +    renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + +    def get_serializer_class(self): +        if self.request.method == 'POST': +            class DynamicSerializer(serializers.ModelSerializer): +                class Meta: +                    model = TwoFieldModel +                    fields = ('field_b',) +        else: +            class DynamicSerializer(serializers.ModelSerializer): +                class Meta: +                    model = TwoFieldModel +        return DynamicSerializer + + +class TestFilterBackendAppliedToViews(TestCase): +    def setUp(self): +        """ +        Create 3 BasicModel instances to filter on. +        """ +        items = ['foo', 'bar', 'baz'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] + +    def test_get_root_view_filters_by_name_with_filter_backend(self): +        """ +        GET requests to ListCreateAPIView should return filtered list. +        """ +        root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) +        request = factory.get('/') +        response = root_view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(len(response.data), 1) +        self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) + +    def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): +        """ +        GET requests to ListCreateAPIView should return empty list when all models are filtered out. +        """ +        root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) +        request = factory.get('/') +        response = root_view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, []) + +    def test_get_instance_view_filters_out_name_with_filter_backend(self): +        """ +        GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. +        """ +        instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) +        request = factory.get('/1') +        response = instance_view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertEqual(response.data, {'detail': 'Not found.'}) + +    def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): +        """ +        GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded +        """ +        instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) +        request = factory.get('/1') +        response = instance_view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) + +    def test_dynamic_serializer_form_in_browsable_api(self): +        """ +        GET requests to ListCreateAPIView should return filtered list. +        """ +        view = DynamicSerializerView.as_view() +        request = factory.get('/') +        response = view(request).render() +        self.assertContains(response, 'field_b') +        self.assertNotContains(response, 'field_a') diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py new file mode 100644 index 00000000..a33b832f --- /dev/null +++ b/tests/test_htmlrenderer.py @@ -0,0 +1,127 @@ +from __future__ import unicode_literals +from django.core.exceptions import PermissionDenied +from django.conf.urls import patterns, url +from django.http import Http404 +from django.test import TestCase +from django.template import TemplateDoesNotExist, Template +from django.utils import six +from rest_framework import status +from rest_framework.decorators import api_view, renderer_classes +from rest_framework.renderers import TemplateHTMLRenderer +from rest_framework.response import Response +import django.template.loader + + +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def example(request): +    """ +    A view that can returns an HTML representation. +    """ +    data = {'object': 'foobar'} +    return Response(data, template_name='example.html') + + +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def permission_denied(request): +    raise PermissionDenied() + + +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def not_found(request): +    raise Http404() + + +urlpatterns = patterns( +    '', +    url(r'^$', example), +    url(r'^permission_denied$', permission_denied), +    url(r'^not_found$', not_found), +) + + +class TemplateHTMLRendererTests(TestCase): +    urls = 'tests.test_htmlrenderer' + +    def setUp(self): +        """ +        Monkeypatch get_template +        """ +        self.get_template = django.template.loader.get_template + +        def get_template(template_name, dirs=None): +            if template_name == 'example.html': +                return Template("example: {{ object }}") +            raise TemplateDoesNotExist(template_name) + +        def select_template(template_name_list, dirs=None, using=None): +            if template_name_list == ['example.html']: +                return Template("example: {{ object }}") +            raise TemplateDoesNotExist(template_name_list[0]) + +        django.template.loader.get_template = get_template +        django.template.loader.select_template = select_template + +    def tearDown(self): +        """ +        Revert monkeypatching +        """ +        django.template.loader.get_template = self.get_template + +    def test_simple_html_view(self): +        response = self.client.get('/') +        self.assertContains(response, "example: foobar") +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + +    def test_not_found_html_view(self): +        response = self.client.get('/not_found') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertEqual(response.content, six.b("404 Not Found")) +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + +    def test_permission_denied_html_view(self): +        response = self.client.get('/permission_denied') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) +        self.assertEqual(response.content, six.b("403 Forbidden")) +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + + +class TemplateHTMLRendererExceptionTests(TestCase): +    urls = 'tests.test_htmlrenderer' + +    def setUp(self): +        """ +        Monkeypatch get_template +        """ +        self.get_template = django.template.loader.get_template + +        def get_template(template_name): +            if template_name == '404.html': +                return Template("404: {{ detail }}") +            if template_name == '403.html': +                return Template("403: {{ detail }}") +            raise TemplateDoesNotExist(template_name) + +        django.template.loader.get_template = get_template + +    def tearDown(self): +        """ +        Revert monkeypatching +        """ +        django.template.loader.get_template = self.get_template + +    def test_not_found_html_view_with_template(self): +        response = self.client.get('/not_found') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertTrue(response.content in ( +            six.b("404: Not found"), six.b("404 Not Found"))) +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + +    def test_permission_denied_html_view_with_template(self): +        response = self.client.get('/permission_denied') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) +        self.assertTrue(response.content in ( +            six.b("403: Permission denied"), six.b("403 Forbidden"))) +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/tests/test_metadata.py b/tests/test_metadata.py new file mode 100644 index 00000000..3a435f02 --- /dev/null +++ b/tests/test_metadata.py @@ -0,0 +1,209 @@ +from __future__ import unicode_literals +from rest_framework import exceptions, serializers, status, views, versioning +from rest_framework.request import Request +from rest_framework.renderers import BrowsableAPIRenderer +from rest_framework.test import APIRequestFactory + +request = Request(APIRequestFactory().options('/')) + + +class TestMetadata: +    def test_metadata(self): +        """ +        OPTIONS requests to views should return a valid 200 response. +        """ +        class ExampleView(views.APIView): +            """Example view.""" +            pass + +        view = ExampleView.as_view() +        response = view(request=request) +        expected = { +            'name': 'Example', +            'description': 'Example view.', +            'renders': [ +                'application/json', +                'text/html' +            ], +            'parses': [ +                'application/json', +                'application/x-www-form-urlencoded', +                'multipart/form-data' +            ] +        } +        assert response.status_code == status.HTTP_200_OK +        assert response.data == expected + +    def test_none_metadata(self): +        """ +        OPTIONS requests to views where `metadata_class = None` should raise +        a MethodNotAllowed exception, which will result in an HTTP 405 response. +        """ +        class ExampleView(views.APIView): +            metadata_class = None + +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED +        assert response.data == {'detail': 'Method "OPTIONS" not allowed.'} + +    def test_actions(self): +        """ +        On generic views OPTIONS should return an 'actions' key with metadata +        on the fields that may be supplied to PUT and POST requests. +        """ +        class ExampleSerializer(serializers.Serializer): +            choice_field = serializers.ChoiceField(['red', 'green', 'blue']) +            integer_field = serializers.IntegerField( +                min_value=1, max_value=1000 +            ) +            char_field = serializers.CharField( +                required=False, min_length=3, max_length=40 +            ) + +        class ExampleView(views.APIView): +            """Example view.""" +            def post(self, request): +                pass + +            def get_serializer(self): +                return ExampleSerializer() + +        view = ExampleView.as_view() +        response = view(request=request) +        expected = { +            'name': 'Example', +            'description': 'Example view.', +            'renders': [ +                'application/json', +                'text/html' +            ], +            'parses': [ +                'application/json', +                'application/x-www-form-urlencoded', +                'multipart/form-data' +            ], +            'actions': { +                'POST': { +                    'choice_field': { +                        'type': 'choice', +                        'required': True, +                        'read_only': False, +                        'label': 'Choice field', +                        'choices': [ +                            {'display_name': 'red', 'value': 'red'}, +                            {'display_name': 'green', 'value': 'green'}, +                            {'display_name': 'blue', 'value': 'blue'} +                        ] +                    }, +                    'integer_field': { +                        'type': 'integer', +                        'required': True, +                        'read_only': False, +                        'label': 'Integer field', +                        'min_value': 1, +                        'max_value': 1000, + +                    }, +                    'char_field': { +                        'type': 'string', +                        'required': False, +                        'read_only': False, +                        'label': 'Char field', +                        'min_length': 3, +                        'max_length': 40 +                    } +                } +            } +        } +        assert response.status_code == status.HTTP_200_OK +        assert response.data == expected + +    def test_global_permissions(self): +        """ +        If a user does not have global permissions on an action, then any +        metadata associated with it should not be included in OPTION responses. +        """ +        class ExampleSerializer(serializers.Serializer): +            choice_field = serializers.ChoiceField(['red', 'green', 'blue']) +            integer_field = serializers.IntegerField(max_value=10) +            char_field = serializers.CharField(required=False) + +        class ExampleView(views.APIView): +            """Example view.""" +            def post(self, request): +                pass + +            def put(self, request): +                pass + +            def get_serializer(self): +                return ExampleSerializer() + +            def check_permissions(self, request): +                if request.method == 'POST': +                    raise exceptions.PermissionDenied() + +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_200_OK +        assert list(response.data['actions'].keys()) == ['PUT'] + +    def test_object_permissions(self): +        """ +        If a user does not have object permissions on an action, then any +        metadata associated with it should not be included in OPTION responses. +        """ +        class ExampleSerializer(serializers.Serializer): +            choice_field = serializers.ChoiceField(['red', 'green', 'blue']) +            integer_field = serializers.IntegerField(max_value=10) +            char_field = serializers.CharField(required=False) + +        class ExampleView(views.APIView): +            """Example view.""" +            def post(self, request): +                pass + +            def put(self, request): +                pass + +            def get_serializer(self): +                return ExampleSerializer() + +            def get_object(self): +                if self.request.method == 'PUT': +                    raise exceptions.PermissionDenied() + +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_200_OK +        assert list(response.data['actions'].keys()) == ['POST'] + +    def test_bug_2455_clone_request(self): +        class ExampleView(views.APIView): +            renderer_classes = (BrowsableAPIRenderer,) + +            def post(self, request): +                pass + +            def get_serializer(self): +                assert hasattr(self.request, 'version') +                return serializers.Serializer() + +        view = ExampleView.as_view() +        view(request=request) + +    def test_bug_2477_clone_request(self): +        class ExampleView(views.APIView): +            renderer_classes = (BrowsableAPIRenderer,) + +            def post(self, request): +                pass + +            def get_serializer(self): +                assert hasattr(self.request, 'versioning_scheme') +                return serializers.Serializer() + +        scheme = versioning.QueryParameterVersioning +        view = ExampleView.as_view(versioning_class=scheme) +        view(request=request) diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 00000000..4c099fca --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,37 @@ + +from django.conf.urls import patterns, url +from django.contrib.auth.models import User +from rest_framework.authentication import TokenAuthentication +from rest_framework.authtoken.models import Token +from rest_framework.test import APITestCase +from rest_framework.views import APIView + + +urlpatterns = patterns( +    '', +    url(r'^$', APIView.as_view(authentication_classes=(TokenAuthentication,))), +) + + +class MyMiddleware(object): + +    def process_response(self, request, response): +        assert hasattr(request, 'user'), '`user` is not set on request' +        assert request.user.is_authenticated(), '`user` is not authenticated' +        return response + + +class TestMiddleware(APITestCase): + +    urls = 'tests.test_middleware' + +    def test_middleware_can_access_user_when_processing_response(self): +        user = User.objects.create_user('john', 'john@example.com', 'password') +        key = 'abcd1234' +        Token.objects.create(key=key, user=user) + +        with self.settings( +            MIDDLEWARE_CLASSES=('tests.test_middleware.MyMiddleware',) +        ): +            auth = 'Token ' + key +            self.client.get('/', HTTP_AUTHORIZATION=auth) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py new file mode 100644 index 00000000..bce2008a --- /dev/null +++ b/tests/test_model_serializer.py @@ -0,0 +1,641 @@ +""" +The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially +shortcuts for automatically creating serializers based on a given model class. + +These tests deal with ensuring that we correctly map the model fields onto +an appropriate set of serializer fields for each case. +""" +from __future__ import unicode_literals +from django.core.exceptions import ImproperlyConfigured +from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator +from django.db import models +from django.test import TestCase +from django.utils import six +from rest_framework import serializers +from rest_framework.compat import unicode_repr + + +def dedent(blocktext): +    return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]]) + + +# Tests for regular field mappings. +# --------------------------------- + +class CustomField(models.Field): +    """ +    A custom model field simply for testing purposes. +    """ +    pass + + +class OneFieldModel(models.Model): +    char_field = models.CharField(max_length=100) + + +class RegularFieldsModel(models.Model): +    """ +    A model class for testing regular flat fields. +    """ +    auto_field = models.AutoField(primary_key=True) +    big_integer_field = models.BigIntegerField() +    boolean_field = models.BooleanField(default=False) +    char_field = models.CharField(max_length=100) +    comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=100) +    date_field = models.DateField() +    datetime_field = models.DateTimeField() +    decimal_field = models.DecimalField(max_digits=3, decimal_places=1) +    email_field = models.EmailField(max_length=100) +    float_field = models.FloatField() +    integer_field = models.IntegerField() +    null_boolean_field = models.NullBooleanField() +    positive_integer_field = models.PositiveIntegerField() +    positive_small_integer_field = models.PositiveSmallIntegerField() +    slug_field = models.SlugField(max_length=100) +    small_integer_field = models.SmallIntegerField() +    text_field = models.TextField() +    time_field = models.TimeField() +    url_field = models.URLField(max_length=100) +    custom_field = CustomField() + +    def method(self): +        return 'method' + + +COLOR_CHOICES = (('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')) + + +class FieldOptionsModel(models.Model): +    value_limit_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(10)]) +    length_limit_field = models.CharField(validators=[MinLengthValidator(3)], max_length=12) +    blank_field = models.CharField(blank=True, max_length=10) +    null_field = models.IntegerField(null=True) +    default_field = models.IntegerField(default=0) +    descriptive_field = models.IntegerField(help_text='Some help text', verbose_name='A label') +    choices_field = models.CharField(max_length=100, choices=COLOR_CHOICES) + + +class TestModelSerializer(TestCase): +    def test_create_method(self): +        class TestSerializer(serializers.ModelSerializer): +            non_model_field = serializers.CharField() + +            class Meta: +                model = OneFieldModel +                fields = ('char_field', 'non_model_field') + +        serializer = TestSerializer(data={ +            'char_field': 'foo', +            'non_model_field': 'bar', +        }) +        serializer.is_valid() +        with self.assertRaises(TypeError) as excinfo: +            serializer.save() +        msginitial = 'Got a `TypeError` when calling `OneFieldModel.objects.create()`.' +        assert str(excinfo.exception).startswith(msginitial) + + +class TestRegularFieldMappings(TestCase): +    def test_regular_fields(self): +        """ +        Model fields should map to their equivelent serializer fields. +        """ +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RegularFieldsModel + +        expected = dedent(""" +            TestSerializer(): +                auto_field = IntegerField(read_only=True) +                big_integer_field = IntegerField() +                boolean_field = BooleanField(required=False) +                char_field = CharField(max_length=100) +                comma_separated_integer_field = CharField(max_length=100, validators=[<django.core.validators.RegexValidator object>]) +                date_field = DateField() +                datetime_field = DateTimeField() +                decimal_field = DecimalField(decimal_places=1, max_digits=3) +                email_field = EmailField(max_length=100) +                float_field = FloatField() +                integer_field = IntegerField() +                null_boolean_field = NullBooleanField(required=False) +                positive_integer_field = IntegerField() +                positive_small_integer_field = IntegerField() +                slug_field = SlugField(max_length=100) +                small_integer_field = IntegerField() +                text_field = CharField(style={'base_template': 'textarea.html'}) +                time_field = TimeField() +                url_field = URLField(max_length=100) +                custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>) +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_field_options(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = FieldOptionsModel + +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                value_limit_field = IntegerField(max_value=10, min_value=1) +                length_limit_field = CharField(max_length=12, min_length=3) +                blank_field = CharField(allow_blank=True, max_length=10, required=False) +                null_field = IntegerField(allow_null=True, required=False) +                default_field = IntegerField(required=False) +                descriptive_field = IntegerField(help_text='Some help text', label='A label') +                choices_field = ChoiceField(choices=[('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')]) +        """) +        if six.PY2: +            # This particular case is too awkward to resolve fully across +            # both py2 and py3. +            expected = expected.replace( +                "('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')", +                "(u'red', u'Red'), (u'blue', u'Blue'), (u'green', u'Green')" +            ) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_method_field(self): +        """ +        Properties and methods on the model should be allowed as `Meta.fields` +        values, and should map to `ReadOnlyField`. +        """ +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RegularFieldsModel +                fields = ('auto_field', 'method') + +        expected = dedent(""" +            TestSerializer(): +                auto_field = IntegerField(read_only=True) +                method = ReadOnlyField() +        """) +        self.assertEqual(repr(TestSerializer()), expected) + +    def test_pk_fields(self): +        """ +        Both `pk` and the actual primary key name are valid in `Meta.fields`. +        """ +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RegularFieldsModel +                fields = ('pk', 'auto_field') + +        expected = dedent(""" +            TestSerializer(): +                pk = IntegerField(label='Auto field', read_only=True) +                auto_field = IntegerField(read_only=True) +        """) +        self.assertEqual(repr(TestSerializer()), expected) + +    def test_extra_field_kwargs(self): +        """ +        Ensure `extra_kwargs` are passed to generated fields. +        """ +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RegularFieldsModel +                fields = ('auto_field', 'char_field') +                extra_kwargs = {'char_field': {'default': 'extra'}} + +        expected = dedent(""" +            TestSerializer(): +                auto_field = IntegerField(read_only=True) +                char_field = CharField(default='extra', max_length=100) +        """) +        self.assertEqual(repr(TestSerializer()), expected) + +    def test_invalid_field(self): +        """ +        Field names that do not map to a model field or relationship should +        raise a configuration errror. +        """ +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RegularFieldsModel +                fields = ('auto_field', 'invalid') + +        with self.assertRaises(ImproperlyConfigured) as excinfo: +            TestSerializer().fields +        expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.' +        assert str(excinfo.exception) == expected + +    def test_missing_field(self): +        """ +        Fields that have been declared on the serializer class must be included +        in the `Meta.fields` if it exists. +        """ +        class TestSerializer(serializers.ModelSerializer): +            missing = serializers.ReadOnlyField() + +            class Meta: +                model = RegularFieldsModel +                fields = ('auto_field',) + +        with self.assertRaises(AssertionError) as excinfo: +            TestSerializer().fields +        expected = ( +            "The field 'missing' was declared on serializer TestSerializer, " +            "but has not been included in the 'fields' option." +        ) +        assert str(excinfo.exception) == expected + +    def test_missing_superclass_field(self): +        """ +        Fields that have been declared on a parent of the serializer class may +        be excluded from the `Meta.fields` option. +        """ +        class TestSerializer(serializers.ModelSerializer): +            missing = serializers.ReadOnlyField() + +            class Meta: +                model = RegularFieldsModel + +        class ChildSerializer(TestSerializer): +            missing = serializers.ReadOnlyField() + +            class Meta: +                model = RegularFieldsModel +                fields = ('auto_field',) + +        ChildSerializer().fields + + +# Tests for relational field mappings. +# ------------------------------------ + +class ForeignKeyTargetModel(models.Model): +    name = models.CharField(max_length=100) + + +class ManyToManyTargetModel(models.Model): +    name = models.CharField(max_length=100) + + +class OneToOneTargetModel(models.Model): +    name = models.CharField(max_length=100) + + +class ThroughTargetModel(models.Model): +    name = models.CharField(max_length=100) + + +class Supplementary(models.Model): +    extra = models.IntegerField() +    forwards = models.ForeignKey('ThroughTargetModel') +    backwards = models.ForeignKey('RelationalModel') + + +class RelationalModel(models.Model): +    foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='reverse_foreign_key') +    many_to_many = models.ManyToManyField(ManyToManyTargetModel, related_name='reverse_many_to_many') +    one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one') +    through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through') + + +class TestRelationalFieldMappings(TestCase): +    def test_pk_relations(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RelationalModel + +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all()) +                one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>]) +                many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all()) +                through = PrimaryKeyRelatedField(many=True, read_only=True) +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_nested_relations(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RelationalModel +                depth = 1 + +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                foreign_key = NestedSerializer(read_only=True): +                    id = IntegerField(label='ID', read_only=True) +                    name = CharField(max_length=100) +                one_to_one = NestedSerializer(read_only=True): +                    id = IntegerField(label='ID', read_only=True) +                    name = CharField(max_length=100) +                many_to_many = NestedSerializer(many=True, read_only=True): +                    id = IntegerField(label='ID', read_only=True) +                    name = CharField(max_length=100) +                through = NestedSerializer(many=True, read_only=True): +                    id = IntegerField(label='ID', read_only=True) +                    name = CharField(max_length=100) +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_hyperlinked_relations(self): +        class TestSerializer(serializers.HyperlinkedModelSerializer): +            class Meta: +                model = RelationalModel + +        expected = dedent(""" +            TestSerializer(): +                url = HyperlinkedIdentityField(view_name='relationalmodel-detail') +                foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail') +                one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>], view_name='onetoonetargetmodel-detail') +                many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') +                through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_nested_hyperlinked_relations(self): +        class TestSerializer(serializers.HyperlinkedModelSerializer): +            class Meta: +                model = RelationalModel +                depth = 1 + +        expected = dedent(""" +            TestSerializer(): +                url = HyperlinkedIdentityField(view_name='relationalmodel-detail') +                foreign_key = NestedSerializer(read_only=True): +                    url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail') +                    name = CharField(max_length=100) +                one_to_one = NestedSerializer(read_only=True): +                    url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') +                    name = CharField(max_length=100) +                many_to_many = NestedSerializer(many=True, read_only=True): +                    url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail') +                    name = CharField(max_length=100) +                through = NestedSerializer(many=True, read_only=True): +                    url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') +                    name = CharField(max_length=100) +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_pk_reverse_foreign_key(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = ForeignKeyTargetModel +                fields = ('id', 'name', 'reverse_foreign_key') + +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                name = CharField(max_length=100) +                reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_pk_reverse_one_to_one(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToOneTargetModel +                fields = ('id', 'name', 'reverse_one_to_one') + +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                name = CharField(max_length=100) +                reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all()) +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_pk_reverse_many_to_many(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = ManyToManyTargetModel +                fields = ('id', 'name', 'reverse_many_to_many') + +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                name = CharField(max_length=100) +                reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + +    def test_pk_reverse_through(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = ThroughTargetModel +                fields = ('id', 'name', 'reverse_through') + +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                name = CharField(max_length=100) +                reverse_through = PrimaryKeyRelatedField(many=True, read_only=True) +        """) +        self.assertEqual(unicode_repr(TestSerializer()), expected) + + +class TestIntegration(TestCase): +    def setUp(self): +        self.foreign_key_target = ForeignKeyTargetModel.objects.create( +            name='foreign_key' +        ) +        self.one_to_one_target = OneToOneTargetModel.objects.create( +            name='one_to_one' +        ) +        self.many_to_many_targets = [ +            ManyToManyTargetModel.objects.create( +                name='many_to_many (%d)' % idx +            ) for idx in range(3) +        ] +        self.instance = RelationalModel.objects.create( +            foreign_key=self.foreign_key_target, +            one_to_one=self.one_to_one_target, +        ) +        self.instance.many_to_many = self.many_to_many_targets +        self.instance.save() + +    def test_pk_retrival(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RelationalModel + +        serializer = TestSerializer(self.instance) +        expected = { +            'id': self.instance.pk, +            'foreign_key': self.foreign_key_target.pk, +            'one_to_one': self.one_to_one_target.pk, +            'many_to_many': [item.pk for item in self.many_to_many_targets], +            'through': [] +        } +        self.assertEqual(serializer.data, expected) + +    def test_pk_create(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RelationalModel + +        new_foreign_key = ForeignKeyTargetModel.objects.create( +            name='foreign_key' +        ) +        new_one_to_one = OneToOneTargetModel.objects.create( +            name='one_to_one' +        ) +        new_many_to_many = [ +            ManyToManyTargetModel.objects.create( +                name='new many_to_many (%d)' % idx +            ) for idx in range(3) +        ] +        data = { +            'foreign_key': new_foreign_key.pk, +            'one_to_one': new_one_to_one.pk, +            'many_to_many': [item.pk for item in new_many_to_many], +        } + +        # Serializer should validate okay. +        serializer = TestSerializer(data=data) +        assert serializer.is_valid() + +        # Creating the instance, relationship attributes should be set. +        instance = serializer.save() +        assert instance.foreign_key.pk == new_foreign_key.pk +        assert instance.one_to_one.pk == new_one_to_one.pk +        assert [ +            item.pk for item in instance.many_to_many.all() +        ] == [ +            item.pk for item in new_many_to_many +        ] +        assert list(instance.through.all()) == [] + +        # Representation should be correct. +        expected = { +            'id': instance.pk, +            'foreign_key': new_foreign_key.pk, +            'one_to_one': new_one_to_one.pk, +            'many_to_many': [item.pk for item in new_many_to_many], +            'through': [] +        } +        self.assertEqual(serializer.data, expected) + +    def test_pk_update(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RelationalModel + +        new_foreign_key = ForeignKeyTargetModel.objects.create( +            name='foreign_key' +        ) +        new_one_to_one = OneToOneTargetModel.objects.create( +            name='one_to_one' +        ) +        new_many_to_many = [ +            ManyToManyTargetModel.objects.create( +                name='new many_to_many (%d)' % idx +            ) for idx in range(3) +        ] +        data = { +            'foreign_key': new_foreign_key.pk, +            'one_to_one': new_one_to_one.pk, +            'many_to_many': [item.pk for item in new_many_to_many], +        } + +        # Serializer should validate okay. +        serializer = TestSerializer(self.instance, data=data) +        assert serializer.is_valid() + +        # Creating the instance, relationship attributes should be set. +        instance = serializer.save() +        assert instance.foreign_key.pk == new_foreign_key.pk +        assert instance.one_to_one.pk == new_one_to_one.pk +        assert [ +            item.pk for item in instance.many_to_many.all() +        ] == [ +            item.pk for item in new_many_to_many +        ] +        assert list(instance.through.all()) == [] + +        # Representation should be correct. +        expected = { +            'id': self.instance.pk, +            'foreign_key': new_foreign_key.pk, +            'one_to_one': new_one_to_one.pk, +            'many_to_many': [item.pk for item in new_many_to_many], +            'through': [] +        } +        self.assertEqual(serializer.data, expected) + + +# Tests for bulk create using `ListSerializer`. + +class BulkCreateModel(models.Model): +    name = models.CharField(max_length=10) + + +class TestBulkCreate(TestCase): +    def test_bulk_create(self): +        class BasicModelSerializer(serializers.ModelSerializer): +            class Meta: +                model = BulkCreateModel +                fields = ('name',) + +        class BulkCreateSerializer(serializers.ListSerializer): +            child = BasicModelSerializer() + +        data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}] +        serializer = BulkCreateSerializer(data=data) +        assert serializer.is_valid() + +        # Objects are returned by save(). +        instances = serializer.save() +        assert len(instances) == 3 +        assert [item.name for item in instances] == ['a', 'b', 'c'] + +        # Objects have been created in the database. +        assert BulkCreateModel.objects.count() == 3 +        assert list(BulkCreateModel.objects.values_list('name', flat=True)) == ['a', 'b', 'c'] + +        # Serializer returns correct data. +        assert serializer.data == data + + +class TestMetaClassModel(models.Model): +    text = models.CharField(max_length=100) + + +class TestSerializerMetaClass(TestCase): +    def test_meta_class_fields_option(self): +        class ExampleSerializer(serializers.ModelSerializer): +            class Meta: +                model = TestMetaClassModel +                fields = 'text' + +        with self.assertRaises(TypeError) as result: +            ExampleSerializer().fields + +        exception = result.exception +        assert str(exception).startswith( +            "The `fields` option must be a list or tuple" +        ) + +    def test_meta_class_exclude_option(self): +        class ExampleSerializer(serializers.ModelSerializer): +            class Meta: +                model = TestMetaClassModel +                exclude = 'text' + +        with self.assertRaises(TypeError) as result: +            ExampleSerializer().fields + +        exception = result.exception +        assert str(exception).startswith( +            "The `exclude` option must be a list or tuple" +        ) + +    def test_meta_class_fields_and_exclude_options(self): +        class ExampleSerializer(serializers.ModelSerializer): +            class Meta: +                model = TestMetaClassModel +                fields = ('text',) +                exclude = ('text',) + +        with self.assertRaises(AssertionError) as result: +            ExampleSerializer().fields + +        exception = result.exception +        self.assertEqual( +            str(exception), +            "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer." +        ) diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py new file mode 100644 index 00000000..15627e1d --- /dev/null +++ b/tests/test_multitable_inheritance.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from tests.models import RESTFrameworkModel + + +# Models +class ParentModel(RESTFrameworkModel): +    name1 = models.CharField(max_length=100) + + +class ChildModel(ParentModel): +    name2 = models.CharField(max_length=100) + + +class AssociatedModel(RESTFrameworkModel): +    ref = models.OneToOneField(ParentModel, primary_key=True) +    name = models.CharField(max_length=100) + + +# Serializers +class DerivedModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ChildModel + + +class AssociatedModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = AssociatedModel + + +# Tests +class InheritedModelSerializationTests(TestCase): + +    def test_multitable_inherited_model_fields_as_expected(self): +        """ +        Assert that the parent pointer field is not included in the fields +        serialized fields +        """ +        child = ChildModel(name1='parent name', name2='child name') +        serializer = DerivedModelSerializer(child) +        self.assertEqual(set(serializer.data.keys()), +                         set(['name1', 'name2', 'id'])) + +    def test_onetoone_primary_key_model_fields_as_expected(self): +        """ +        Assert that a model with a onetoone field that is the primary key is +        not treated like a derived model +        """ +        parent = ParentModel.objects.create(name1='parent name') +        associate = AssociatedModel.objects.create(name='hello', ref=parent) +        serializer = AssociatedModelSerializer(associate) +        self.assertEqual(set(serializer.data.keys()), +                         set(['name', 'ref'])) + +    def test_data_is_valid_without_parent_ptr(self): +        """ +        Assert that the pointer to the parent table is not a required field +        for input data +        """ +        data = { +            'name1': 'parent name', +            'name2': 'child name', +        } +        serializer = DerivedModelSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py new file mode 100644 index 00000000..04b89eb6 --- /dev/null +++ b/tests/test_negotiation.py @@ -0,0 +1,45 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.negotiation import DefaultContentNegotiation +from rest_framework.request import Request +from rest_framework.renderers import BaseRenderer +from rest_framework.test import APIRequestFactory + + +factory = APIRequestFactory() + + +class MockJSONRenderer(BaseRenderer): +    media_type = 'application/json' + + +class MockHTMLRenderer(BaseRenderer): +    media_type = 'text/html' + + +class NoCharsetSpecifiedRenderer(BaseRenderer): +    media_type = 'my/media' + + +class TestAcceptedMediaType(TestCase): +    def setUp(self): +        self.renderers = [MockJSONRenderer(), MockHTMLRenderer()] +        self.negotiator = DefaultContentNegotiation() + +    def select_renderer(self, request): +        return self.negotiator.select_renderer(request, self.renderers) + +    def test_client_without_accept_use_renderer(self): +        request = Request(factory.get('/')) +        accepted_renderer, accepted_media_type = self.select_renderer(request) +        self.assertEqual(accepted_media_type, 'application/json') + +    def test_client_underspecifies_accept_use_renderer(self): +        request = Request(factory.get('/', HTTP_ACCEPT='*/*')) +        accepted_renderer, accepted_media_type = self.select_renderer(request) +        self.assertEqual(accepted_media_type, 'application/json') + +    def test_client_overspecifies_accept_use_client(self): +        request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) +        accepted_renderer, accepted_media_type = self.select_renderer(request) +        self.assertEqual(accepted_media_type, 'application/json; indent=8') diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 00000000..6b39a6f2 --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,671 @@ +# coding: utf-8 +from __future__ import unicode_literals +from rest_framework import exceptions, generics, pagination, serializers, status, filters +from rest_framework.request import Request +from rest_framework.pagination import PageLink, PAGE_BREAK +from rest_framework.test import APIRequestFactory +import pytest + +factory = APIRequestFactory() + + +class TestPaginationIntegration: +    """ +    Integration tests. +    """ + +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item + +        class EvenItemsOnly(filters.BaseFilterBackend): +            def filter_queryset(self, request, queryset, view): +                return [item for item in queryset if item % 2 == 0] + +        class BasicPagination(pagination.PageNumberPagination): +            page_size = 5 +            page_size_query_param = 'page_size' +            max_page_size = 20 + +        self.view = generics.ListAPIView.as_view( +            serializer_class=PassThroughSerializer, +            queryset=range(1, 101), +            filter_backends=[EvenItemsOnly], +            pagination_class=BasicPagination +        ) + +    def test_filtered_items_are_paginated(self): +        request = factory.get('/', {'page': 2}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [12, 14, 16, 18, 20], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page=3', +            'count': 50 +        } + +    def test_setting_page_size(self): +        """ +        When 'paginate_by_param' is set, the client may choose a page size. +        """ +        request = factory.get('/', {'page_size': 10}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=10', +            'count': 50 +        } + +    def test_setting_page_size_over_maximum(self): +        """ +        When page_size parameter exceeds maxiumum allowable, +        then it should be capped to the maxiumum. +        """ +        request = factory.get('/', {'page_size': 1000}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [ +                2, 4, 6, 8, 10, 12, 14, 16, 18, 20, +                22, 24, 26, 28, 30, 32, 34, 36, 38, 40 +            ], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=1000', +            'count': 50 +        } + +    def test_setting_page_size_to_zero(self): +        """ +        When page_size parameter is invalid it should return to the default. +        """ +        request = factory.get('/', {'page_size': 0}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [2, 4, 6, 8, 10], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=0', +            'count': 50 +        } + +    def test_additional_query_params_are_preserved(self): +        request = factory.get('/', {'page': 2, 'filter': 'even'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [12, 14, 16, 18, 20], +            'previous': 'http://testserver/?filter=even', +            'next': 'http://testserver/?filter=even&page=3', +            'count': 50 +        } + +    def test_404_not_found_for_zero_page(self): +        request = factory.get('/', {'page': '0'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND +        assert response.data == { +            'detail': 'Invalid page "0": That page number is less than 1.' +        } + +    def test_404_not_found_for_invalid_page(self): +        request = factory.get('/', {'page': 'invalid'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND +        assert response.data == { +            'detail': 'Invalid page "invalid": That page number is not an integer.' +        } + + +class TestPaginationDisabledIntegration: +    """ +    Integration tests for disabled pagination. +    """ + +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item + +        self.view = generics.ListAPIView.as_view( +            serializer_class=PassThroughSerializer, +            queryset=range(1, 101), +            pagination_class=None +        ) + +    def test_unpaginated_list(self): +        request = factory.get('/', {'page': 2}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == list(range(1, 101)) + + +class TestDeprecatedStylePagination: +    """ +    Integration tests for deprecated style of setting pagination +    attributes on the view. +    """ + +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item + +        class ExampleView(generics.ListAPIView): +            serializer_class = PassThroughSerializer +            queryset = range(1, 101) +            pagination_class = pagination.PageNumberPagination +            paginate_by = 20 +            page_query_param = 'page_number' + +        self.view = ExampleView.as_view() + +    def test_paginate_by_attribute_on_view(self): +        request = factory.get('/?page_number=2') +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [ +                21, 22, 23, 24, 25, 26, 27, 28, 29, 30, +                31, 32, 33, 34, 35, 36, 37, 38, 39, 40 +            ], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page_number=3', +            'count': 100 +        } + + +class TestPageNumberPagination: +    """ +    Unit tests for `pagination.PageNumberPagination`. +    """ + +    def setup(self): +        class ExamplePagination(pagination.PageNumberPagination): +            page_size = 5 +        self.pagination = ExamplePagination() +        self.queryset = range(1, 101) + +    def paginate_queryset(self, request): +        return list(self.pagination.paginate_queryset(self.queryset, request)) + +    def get_paginated_content(self, queryset): +        response = self.pagination.get_paginated_response(queryset) +        return response.data + +    def get_html_context(self): +        return self.pagination.get_html_context() + +    def test_no_page_number(self): +        request = Request(factory.get('/')) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [1, 2, 3, 4, 5] +        assert content == { +            'results': [1, 2, 3, 4, 5], +            'previous': None, +            'next': 'http://testserver/?page=2', +            'count': 100 +        } +        assert context == { +            'previous_url': None, +            'next_url': 'http://testserver/?page=2', +            'page_links': [ +                PageLink('http://testserver/', 1, True, False), +                PageLink('http://testserver/?page=2', 2, False, False), +                PageLink('http://testserver/?page=3', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=20', 20, False, False), +            ] +        } +        assert self.pagination.display_page_controls +        assert isinstance(self.pagination.to_html(), type('')) + +    def test_second_page(self): +        request = Request(factory.get('/', {'page': 2})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [6, 7, 8, 9, 10] +        assert content == { +            'results': [6, 7, 8, 9, 10], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page=3', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/', +            'next_url': 'http://testserver/?page=3', +            'page_links': [ +                PageLink('http://testserver/', 1, False, False), +                PageLink('http://testserver/?page=2', 2, True, False), +                PageLink('http://testserver/?page=3', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=20', 20, False, False), +            ] +        } + +    def test_last_page(self): +        request = Request(factory.get('/', {'page': 'last'})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [96, 97, 98, 99, 100] +        assert content == { +            'results': [96, 97, 98, 99, 100], +            'previous': 'http://testserver/?page=19', +            'next': None, +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?page=19', +            'next_url': None, +            'page_links': [ +                PageLink('http://testserver/', 1, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=18', 18, False, False), +                PageLink('http://testserver/?page=19', 19, False, False), +                PageLink('http://testserver/?page=20', 20, True, False), +            ] +        } + +    def test_invalid_page(self): +        request = Request(factory.get('/', {'page': 'invalid'})) +        with pytest.raises(exceptions.NotFound): +            self.paginate_queryset(request) + + +class TestLimitOffset: +    """ +    Unit tests for `pagination.LimitOffsetPagination`. +    """ + +    def setup(self): +        class ExamplePagination(pagination.LimitOffsetPagination): +            default_limit = 10 +        self.pagination = ExamplePagination() +        self.queryset = range(1, 101) + +    def paginate_queryset(self, request): +        return list(self.pagination.paginate_queryset(self.queryset, request)) + +    def get_paginated_content(self, queryset): +        response = self.pagination.get_paginated_response(queryset) +        return response.data + +    def get_html_context(self): +        return self.pagination.get_html_context() + +    def test_no_offset(self): +        request = Request(factory.get('/', {'limit': 5})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [1, 2, 3, 4, 5] +        assert content == { +            'results': [1, 2, 3, 4, 5], +            'previous': None, +            'next': 'http://testserver/?limit=5&offset=5', +            'count': 100 +        } +        assert context == { +            'previous_url': None, +            'next_url': 'http://testserver/?limit=5&offset=5', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, True, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, False, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } +        assert self.pagination.display_page_controls +        assert isinstance(self.pagination.to_html(), type('')) + +    def test_single_offset(self): +        """ +        When the offset is not a multiple of the limit we get some edge cases: +        * The first page should still be offset zero. +        * We may end up displaying an extra page in the pagination control. +        """ +        request = Request(factory.get('/', {'limit': 5, 'offset': 1})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [2, 3, 4, 5, 6] +        assert content == { +            'results': [2, 3, 4, 5, 6], +            'previous': 'http://testserver/?limit=5', +            'next': 'http://testserver/?limit=5&offset=6', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5', +            'next_url': 'http://testserver/?limit=5&offset=6', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=1', 2, True, False), +                PageLink('http://testserver/?limit=5&offset=6', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=96', 21, False, False), +            ] +        } + +    def test_first_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 5})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [6, 7, 8, 9, 10] +        assert content == { +            'results': [6, 7, 8, 9, 10], +            'previous': 'http://testserver/?limit=5', +            'next': 'http://testserver/?limit=5&offset=10', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5', +            'next_url': 'http://testserver/?limit=5&offset=10', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, True, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } + +    def test_middle_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 10})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [11, 12, 13, 14, 15] +        assert content == { +            'results': [11, 12, 13, 14, 15], +            'previous': 'http://testserver/?limit=5&offset=5', +            'next': 'http://testserver/?limit=5&offset=15', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5&offset=5', +            'next_url': 'http://testserver/?limit=5&offset=15', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, False, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, True, False), +                PageLink('http://testserver/?limit=5&offset=15', 4, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } + +    def test_ending_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 95})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [96, 97, 98, 99, 100] +        assert content == { +            'results': [96, 97, 98, 99, 100], +            'previous': 'http://testserver/?limit=5&offset=90', +            'next': None, +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5&offset=90', +            'next_url': None, +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=85', 18, False, False), +                PageLink('http://testserver/?limit=5&offset=90', 19, False, False), +                PageLink('http://testserver/?limit=5&offset=95', 20, True, False), +            ] +        } + +    def test_invalid_offset(self): +        """ +        An invalid offset query param should be treated as 0. +        """ +        request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) +        queryset = self.paginate_queryset(request) +        assert queryset == [1, 2, 3, 4, 5] + +    def test_invalid_limit(self): +        """ +        An invalid limit query param should be ignored in favor of the default. +        """ +        request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) +        queryset = self.paginate_queryset(request) +        assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +class TestCursorPagination: +    """ +    Unit tests for `pagination.CursorPagination`. +    """ + +    def setup(self): +        class MockObject(object): +            def __init__(self, idx): +                self.created = idx + +        class MockQuerySet(object): +            def __init__(self, items): +                self.items = items + +            def filter(self, created__gt=None, created__lt=None): +                if created__gt is not None: +                    return MockQuerySet([ +                        item for item in self.items +                        if item.created > int(created__gt) +                    ]) + +                assert created__lt is not None +                return MockQuerySet([ +                    item for item in self.items +                    if item.created < int(created__lt) +                ]) + +            def order_by(self, *ordering): +                if ordering[0].startswith('-'): +                    return MockQuerySet(list(reversed(self.items))) +                return self + +            def __getitem__(self, sliced): +                return self.items[sliced] + +        class ExamplePagination(pagination.CursorPagination): +            page_size = 5 +            ordering = 'created' + +        self.pagination = ExamplePagination() +        self.queryset = MockQuerySet([ +            MockObject(idx) for idx in [ +                1, 1, 1, 1, 1, +                1, 2, 3, 4, 4, +                4, 4, 5, 6, 7, +                7, 7, 7, 7, 7, +                7, 7, 7, 8, 9, +                9, 9, 9, 9, 9 +            ] +        ]) + +    def get_pages(self, url): +        """ +        Given a URL return a tuple of: + +        (previous page, current page, next page, previous url, next url) +        """ +        request = Request(factory.get(url)) +        queryset = self.pagination.paginate_queryset(self.queryset, request) +        current = [item.created for item in queryset] + +        next_url = self.pagination.get_next_link() +        previous_url = self.pagination.get_previous_link() + +        if next_url is not None: +            request = Request(factory.get(next_url)) +            queryset = self.pagination.paginate_queryset(self.queryset, request) +            next = [item.created for item in queryset] +        else: +            next = None + +        if previous_url is not None: +            request = Request(factory.get(previous_url)) +            queryset = self.pagination.paginate_queryset(self.queryset, request) +            previous = [item.created for item in queryset] +        else: +            previous = None + +        return (previous, current, next, previous_url, next_url) + +    def test_invalid_cursor(self): +        request = Request(factory.get('/', {'cursor': '123'})) +        with pytest.raises(exceptions.NotFound): +            self.pagination.paginate_queryset(self.queryset, request) + +    def test_use_with_ordering_filter(self): +        class MockView: +            filter_backends = (filters.OrderingFilter,) +            ordering_fields = ['username', 'created'] +            ordering = 'created' + +        request = Request(factory.get('/', {'ordering': 'username'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('username',) + +        request = Request(factory.get('/', {'ordering': '-username'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('-username',) + +        request = Request(factory.get('/', {'ordering': 'invalid'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('created',) + +    def test_cursor_pagination(self): +        (previous, current, next, previous_url, next_url) = self.get_pages('/') + +        assert previous is None +        assert current == [1, 1, 1, 1, 1] +        assert next == [1, 2, 3, 4, 4] + +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + +        assert previous == [1, 1, 1, 1, 1] +        assert current == [1, 2, 3, 4, 4] +        assert next == [4, 4, 5, 6, 7] + +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + +        assert previous == [1, 2, 3, 4, 4] +        assert current == [4, 4, 5, 6, 7] +        assert next == [7, 7, 7, 7, 7] + +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + +        assert previous == [4, 4, 4, 5, 6]  # Paging artifact +        assert current == [7, 7, 7, 7, 7] +        assert next == [7, 7, 7, 8, 9] + +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + +        assert previous == [7, 7, 7, 7, 7] +        assert current == [7, 7, 7, 8, 9] +        assert next == [9, 9, 9, 9, 9] + +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + +        assert previous == [7, 7, 7, 8, 9] +        assert current == [9, 9, 9, 9, 9] +        assert next is None + +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + +        assert previous == [7, 7, 7, 7, 7] +        assert current == [7, 7, 7, 8, 9] +        assert next == [9, 9, 9, 9, 9] + +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + +        assert previous == [4, 4, 5, 6, 7] +        assert current == [7, 7, 7, 7, 7] +        assert next == [8, 9, 9, 9, 9]  # Paging artifact + +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + +        assert previous == [1, 2, 3, 4, 4] +        assert current == [4, 4, 5, 6, 7] +        assert next == [7, 7, 7, 7, 7] + +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + +        assert previous == [1, 1, 1, 1, 1] +        assert current == [1, 2, 3, 4, 4] +        assert next == [4, 4, 5, 6, 7] + +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + +        assert previous is None +        assert current == [1, 1, 1, 1, 1] +        assert next == [1, 2, 3, 4, 4] + +        assert isinstance(self.pagination.to_html(), type('')) + + +def test_get_displayed_page_numbers(): +    """ +    Test our contextual page display function. + +    This determines which pages to display in a pagination control, +    given the current page and the last page. +    """ +    displayed_page_numbers = pagination._get_displayed_page_numbers + +    # At five pages or less, all pages are displayed, always. +    assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5] + +    # Between six and either pages we may have a single page break. +    assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6] +    assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6] +    assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6] +    assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6] +    assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6] +    assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6] + +    assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7] +    assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7] +    assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7] +    assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7] +    assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7] +    assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7] +    assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7] + +    assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8] +    assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8] +    assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8] +    assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8] +    assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8] +    assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8] +    assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8] +    assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8] + +    # At nine or more pages we may have two page breaks, one on each side. +    assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9] +    assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9] +    assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9] +    assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9] +    assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9] +    assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9] +    assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] +    assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] +    assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 00000000..fe6aec19 --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals +from django import forms +from django.core.files.uploadhandler import MemoryFileUploadHandler +from django.test import TestCase +from django.utils.six.moves import StringIO +from rest_framework.exceptions import ParseError +from rest_framework.parsers import FormParser, FileUploadParser + + +class Form(forms.Form): +    field1 = forms.CharField(max_length=3) +    field2 = forms.CharField() + + +class TestFormParser(TestCase): +    def setUp(self): +        self.string = "field1=abc&field2=defghijk" + +    def test_parse(self): +        """ Make sure the `QueryDict` works OK """ +        parser = FormParser() + +        stream = StringIO(self.string) +        data = parser.parse(stream) + +        self.assertEqual(Form(data).is_valid(), True) + + +class TestFileUploadParser(TestCase): +    def setUp(self): +        class MockRequest(object): +            pass +        from io import BytesIO +        self.stream = BytesIO( +            "Test text file".encode('utf-8') +        ) +        request = MockRequest() +        request.upload_handlers = (MemoryFileUploadHandler(),) +        request.META = { +            'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt', +            'HTTP_CONTENT_LENGTH': 14, +        } +        self.parser_context = {'request': request, 'kwargs': {}} + +    def test_parse(self): +        """ +        Parse raw file upload. +        """ +        parser = FileUploadParser() +        self.stream.seek(0) +        data_and_files = parser.parse(self.stream, None, self.parser_context) +        file_obj = data_and_files.files['file'] +        self.assertEqual(file_obj._size, 14) + +    def test_parse_missing_filename(self): +        """ +        Parse raw file upload when filename is missing. +        """ +        parser = FileUploadParser() +        self.stream.seek(0) +        self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' +        with self.assertRaises(ParseError): +            parser.parse(self.stream, None, self.parser_context) + +    def test_parse_missing_filename_multiple_upload_handlers(self): +        """ +        Parse raw file upload with multiple handlers when filename is missing. +        Regression test for #2109. +        """ +        parser = FileUploadParser() +        self.stream.seek(0) +        self.parser_context['request'].upload_handlers = ( +            MemoryFileUploadHandler(), +            MemoryFileUploadHandler() +        ) +        self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' +        with self.assertRaises(ParseError): +            parser.parse(self.stream, None, self.parser_context) + +    def test_get_filename(self): +        parser = FileUploadParser() +        filename = parser.get_filename(self.stream, None, self.parser_context) +        self.assertEqual(filename, 'file.txt') + +    def test_get_encoded_filename(self): +        parser = FileUploadParser() + +        self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt') +        filename = parser.get_filename(self.stream, None, self.parser_context) +        self.assertEqual(filename, 'ÀĥƦ.txt') + +        self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt') +        filename = parser.get_filename(self.stream, None, self.parser_context) +        self.assertEqual(filename, 'ÀĥƦ.txt') + +        self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt') +        filename = parser.get_filename(self.stream, None, self.parser_context) +        self.assertEqual(filename, 'ÀĥƦ.txt') + +    def __replace_content_disposition(self, disposition): +        self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition diff --git a/tests/test_permissions.py b/tests/test_permissions.py new file mode 100644 index 00000000..97bac33d --- /dev/null +++ b/tests/test_permissions.py @@ -0,0 +1,312 @@ +from __future__ import unicode_literals +from django.contrib.auth.models import User, Permission, Group +from django.db import models +from django.test import TestCase +from django.utils import unittest +from rest_framework import generics, serializers, status, permissions, authentication, HTTP_HEADER_ENCODING +from rest_framework.compat import guardian, get_model_name +from rest_framework.filters import DjangoObjectPermissionsFilter +from rest_framework.test import APIRequestFactory +from tests.models import BasicModel +import base64 + +factory = APIRequestFactory() + + +class BasicSerializer(serializers.ModelSerializer): +    class Meta: +        model = BasicModel + + +class RootView(generics.ListCreateAPIView): +    queryset = BasicModel.objects.all() +    serializer_class = BasicSerializer +    authentication_classes = [authentication.BasicAuthentication] +    permission_classes = [permissions.DjangoModelPermissions] + + +class InstanceView(generics.RetrieveUpdateDestroyAPIView): +    queryset = BasicModel.objects.all() +    serializer_class = BasicSerializer +    authentication_classes = [authentication.BasicAuthentication] +    permission_classes = [permissions.DjangoModelPermissions] + +root_view = RootView.as_view() +instance_view = InstanceView.as_view() + + +def basic_auth_header(username, password): +    credentials = ('%s:%s' % (username, password)) +    base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) +    return 'Basic %s' % base64_credentials + + +class ModelPermissionsIntegrationTests(TestCase): +    def setUp(self): +        User.objects.create_user('disallowed', 'disallowed@example.com', 'password') +        user = User.objects.create_user('permitted', 'permitted@example.com', 'password') +        user.user_permissions = [ +            Permission.objects.get(codename='add_basicmodel'), +            Permission.objects.get(codename='change_basicmodel'), +            Permission.objects.get(codename='delete_basicmodel') +        ] +        user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') +        user.user_permissions = [ +            Permission.objects.get(codename='change_basicmodel'), +        ] + +        self.permitted_credentials = basic_auth_header('permitted', 'password') +        self.disallowed_credentials = basic_auth_header('disallowed', 'password') +        self.updateonly_credentials = basic_auth_header('updateonly', 'password') + +        BasicModel(text='foo').save() + +    def test_has_create_permissions(self): +        request = factory.post('/', {'text': 'foobar'}, format='json', +                               HTTP_AUTHORIZATION=self.permitted_credentials) +        response = root_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) + +    def test_has_put_permissions(self): +        request = factory.put('/1', {'text': 'foobar'}, format='json', +                              HTTP_AUTHORIZATION=self.permitted_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_has_delete_permissions(self): +        request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) +        response = instance_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + +    def test_does_not_have_create_permissions(self): +        request = factory.post('/', {'text': 'foobar'}, format='json', +                               HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = root_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_does_not_have_put_permissions(self): +        request = factory.put('/1', {'text': 'foobar'}, format='json', +                              HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_does_not_have_delete_permissions(self): +        request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = instance_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_options_permitted(self): +        request = factory.options( +            '/', +            HTTP_AUTHORIZATION=self.permitted_credentials +        ) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEqual(list(response.data['actions'].keys()), ['POST']) + +        request = factory.options( +            '/1', +            HTTP_AUTHORIZATION=self.permitted_credentials +        ) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + +    def test_options_disallowed(self): +        request = factory.options( +            '/', +            HTTP_AUTHORIZATION=self.disallowed_credentials +        ) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertNotIn('actions', response.data) + +        request = factory.options( +            '/1', +            HTTP_AUTHORIZATION=self.disallowed_credentials +        ) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertNotIn('actions', response.data) + +    def test_options_updateonly(self): +        request = factory.options( +            '/', +            HTTP_AUTHORIZATION=self.updateonly_credentials +        ) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertNotIn('actions', response.data) + +        request = factory.options( +            '/1', +            HTTP_AUTHORIZATION=self.updateonly_credentials +        ) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + + +class BasicPermModel(models.Model): +    text = models.CharField(max_length=100) + +    class Meta: +        app_label = 'tests' +        permissions = ( +            ('view_basicpermmodel', 'Can view basic perm model'), +            # add, change, delete built in to django +        ) + + +class BasicPermSerializer(serializers.ModelSerializer): +    class Meta: +        model = BasicPermModel + + +# Custom object-level permission, that includes 'view' permissions +class ViewObjectPermissions(permissions.DjangoObjectPermissions): +    perms_map = { +        'GET': ['%(app_label)s.view_%(model_name)s'], +        'OPTIONS': ['%(app_label)s.view_%(model_name)s'], +        'HEAD': ['%(app_label)s.view_%(model_name)s'], +        'POST': ['%(app_label)s.add_%(model_name)s'], +        'PUT': ['%(app_label)s.change_%(model_name)s'], +        'PATCH': ['%(app_label)s.change_%(model_name)s'], +        'DELETE': ['%(app_label)s.delete_%(model_name)s'], +    } + + +class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView): +    queryset = BasicPermModel.objects.all() +    serializer_class = BasicPermSerializer +    authentication_classes = [authentication.BasicAuthentication] +    permission_classes = [ViewObjectPermissions] + +object_permissions_view = ObjectPermissionInstanceView.as_view() + + +class ObjectPermissionListView(generics.ListAPIView): +    queryset = BasicPermModel.objects.all() +    serializer_class = BasicPermSerializer +    authentication_classes = [authentication.BasicAuthentication] +    permission_classes = [ViewObjectPermissions] + +object_permissions_list_view = ObjectPermissionListView.as_view() + + +@unittest.skipUnless(guardian, 'django-guardian not installed') +class ObjectPermissionsIntegrationTests(TestCase): +    """ +    Integration tests for the object level permissions API. +    """ +    def setUp(self): +        from guardian.shortcuts import assign_perm + +        # create users +        create = User.objects.create_user +        users = { +            'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), +            'readonly': create('readonly', 'readonly@example.com', 'password'), +            'writeonly': create('writeonly', 'writeonly@example.com', 'password'), +            'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), +        } + +        # give everyone model level permissions, as we are not testing those +        everyone = Group.objects.create(name='everyone') +        model_name = get_model_name(BasicPermModel) +        app_label = BasicPermModel._meta.app_label +        f = '{0}_{1}'.format +        perms = { +            'view': f('view', model_name), +            'change': f('change', model_name), +            'delete': f('delete', model_name) +        } +        for perm in perms.values(): +            perm = '{0}.{1}'.format(app_label, perm) +            assign_perm(perm, everyone) +        everyone.user_set.add(*users.values()) + +        # appropriate object level permissions +        readers = Group.objects.create(name='readers') +        writers = Group.objects.create(name='writers') +        deleters = Group.objects.create(name='deleters') + +        model = BasicPermModel.objects.create(text='foo') + +        assign_perm(perms['view'], readers, model) +        assign_perm(perms['change'], writers, model) +        assign_perm(perms['delete'], deleters, model) + +        readers.user_set.add(users['fullaccess'], users['readonly']) +        writers.user_set.add(users['fullaccess'], users['writeonly']) +        deleters.user_set.add(users['fullaccess'], users['deleteonly']) + +        self.credentials = {} +        for user in users.values(): +            self.credentials[user.username] = basic_auth_header(user.username, 'password') + +    # Delete +    def test_can_delete_permissions(self): +        request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + +    def test_cannot_delete_permissions(self): +        request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    # Update +    def test_can_update_permissions(self): +        request = factory.patch( +            '/1', {'text': 'foobar'}, format='json', +            HTTP_AUTHORIZATION=self.credentials['writeonly'] +        ) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data.get('text'), 'foobar') + +    def test_cannot_update_permissions(self): +        request = factory.patch( +            '/1', {'text': 'foobar'}, format='json', +            HTTP_AUTHORIZATION=self.credentials['deleteonly'] +        ) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    def test_cannot_update_permissions_non_existing(self): +        request = factory.patch( +            '/999', {'text': 'foobar'}, format='json', +            HTTP_AUTHORIZATION=self.credentials['deleteonly'] +        ) +        response = object_permissions_view(request, pk='999') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    # Read +    def test_can_read_permissions(self): +        request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_cannot_read_permissions(self): +        request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    # Read list +    def test_can_read_list_permissions(self): +        request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) +        object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) +        response = object_permissions_list_view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data[0].get('id'), 1) + +    def test_cannot_read_list_permissions(self): +        request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) +        object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) +        response = object_permissions_list_view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertListEqual(response.data, []) diff --git a/tests/test_relations.py b/tests/test_relations.py new file mode 100644 index 00000000..fbe176e2 --- /dev/null +++ b/tests/test_relations.py @@ -0,0 +1,169 @@ +from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset +from django.core.exceptions import ImproperlyConfigured +from django.utils.datastructures import MultiValueDict +from rest_framework import serializers +from rest_framework.fields import empty +from rest_framework.test import APISimpleTestCase +import pytest + + +class TestStringRelatedField(APISimpleTestCase): +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.StringRelatedField() + +    def test_string_related_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == '<MockObject name=foo, pk=1>' + + +class TestPrimaryKeyRelatedField(APISimpleTestCase): +    def setUp(self): +        self.queryset = MockQueryset([ +            MockObject(pk=1, name='foo'), +            MockObject(pk=2, name='bar'), +            MockObject(pk=3, name='baz') +        ]) +        self.instance = self.queryset.items[2] +        self.field = serializers.PrimaryKeyRelatedField(queryset=self.queryset) + +    def test_pk_related_lookup_exists(self): +        instance = self.field.to_internal_value(self.instance.pk) +        assert instance is self.instance + +    def test_pk_related_lookup_does_not_exist(self): +        with pytest.raises(serializers.ValidationError) as excinfo: +            self.field.to_internal_value(4) +        msg = excinfo.value.detail[0] +        assert msg == 'Invalid pk "4" - object does not exist.' + +    def test_pk_related_lookup_invalid_type(self): +        with pytest.raises(serializers.ValidationError) as excinfo: +            self.field.to_internal_value(BadType()) +        msg = excinfo.value.detail[0] +        assert msg == 'Incorrect type. Expected pk value, received BadType.' + +    def test_pk_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == self.instance.pk + + +class TestHyperlinkedIdentityField(APISimpleTestCase): +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.HyperlinkedIdentityField(view_name='example') +        self.field.reverse = mock_reverse +        self.field._context = {'request': True} + +    def test_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == 'http://example.org/example/1/' + +    def test_representation_unsaved_object(self): +        representation = self.field.to_representation(MockObject(pk=None)) +        assert representation is None + +    def test_representation_with_format(self): +        self.field._context['format'] = 'xml' +        representation = self.field.to_representation(self.instance) +        assert representation == 'http://example.org/example/1.xml/' + +    def test_improperly_configured(self): +        """ +        If a matching view cannot be reversed with the given instance, +        the the user has misconfigured something, as the URL conf and the +        hyperlinked field do not match. +        """ +        self.field.reverse = fail_reverse +        with pytest.raises(ImproperlyConfigured): +            self.field.to_representation(self.instance) + + +class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase): +    """ +    Tests for a hyperlinked identity field that has a `format` set, +    which enforces that alternate formats are never linked too. + +    Eg. If your API includes some endpoints that accept both `.xml` and `.json`, +    but other endpoints that only accept `.json`, we allow for hyperlinked +    relationships that enforce only a single suffix type. +    """ + +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json') +        self.field.reverse = mock_reverse +        self.field._context = {'request': True} + +    def test_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == 'http://example.org/example/1/' + +    def test_representation_with_format(self): +        self.field._context['format'] = 'xml' +        representation = self.field.to_representation(self.instance) +        assert representation == 'http://example.org/example/1.json/' + + +class TestSlugRelatedField(APISimpleTestCase): +    def setUp(self): +        self.queryset = MockQueryset([ +            MockObject(pk=1, name='foo'), +            MockObject(pk=2, name='bar'), +            MockObject(pk=3, name='baz') +        ]) +        self.instance = self.queryset.items[2] +        self.field = serializers.SlugRelatedField( +            slug_field='name', queryset=self.queryset +        ) + +    def test_slug_related_lookup_exists(self): +        instance = self.field.to_internal_value(self.instance.name) +        assert instance is self.instance + +    def test_slug_related_lookup_does_not_exist(self): +        with pytest.raises(serializers.ValidationError) as excinfo: +            self.field.to_internal_value('doesnotexist') +        msg = excinfo.value.detail[0] +        assert msg == 'Object with name=doesnotexist does not exist.' + +    def test_slug_related_lookup_invalid_type(self): +        with pytest.raises(serializers.ValidationError) as excinfo: +            self.field.to_internal_value(BadType()) +        msg = excinfo.value.detail[0] +        assert msg == 'Invalid value.' + +    def test_representation(self): +        representation = self.field.to_representation(self.instance) +        assert representation == self.instance.name + + +class TestManyRelatedField(APISimpleTestCase): +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.StringRelatedField(many=True) +        self.field.field_name = 'foo' + +    def test_get_value_regular_dictionary_full(self): +        assert 'bar' == self.field.get_value({'foo': 'bar'}) +        assert empty == self.field.get_value({'baz': 'bar'}) + +    def test_get_value_regular_dictionary_partial(self): +        setattr(self.field.root, 'partial', True) +        assert 'bar' == self.field.get_value({'foo': 'bar'}) +        assert empty == self.field.get_value({'baz': 'bar'}) + +    def test_get_value_multi_dictionary_full(self): +        mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) +        assert ['bar1', 'bar2'] == self.field.get_value(mvd) + +        mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) +        assert [] == self.field.get_value(mvd) + +    def test_get_value_multi_dictionary_partial(self): +        setattr(self.field.root, 'partial', True) +        mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) +        assert ['bar1', 'bar2'] == self.field.get_value(mvd) + +        mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) +        assert empty == self.field.get_value(mvd) diff --git a/tests/test_relations_generic.py b/tests/test_relations_generic.py new file mode 100644 index 00000000..b600b333 --- /dev/null +++ b/tests/test_relations_generic.py @@ -0,0 +1,104 @@ +from __future__ import unicode_literals +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey +from django.db import models +from django.test import TestCase +from django.utils.encoding import python_2_unicode_compatible +from rest_framework import serializers + + +@python_2_unicode_compatible +class Tag(models.Model): +    """ +    Tags have a descriptive slug, and are attached to an arbitrary object. +    """ +    tag = models.SlugField() +    content_type = models.ForeignKey(ContentType) +    object_id = models.PositiveIntegerField() +    tagged_item = GenericForeignKey('content_type', 'object_id') + +    def __str__(self): +        return self.tag + + +@python_2_unicode_compatible +class Bookmark(models.Model): +    """ +    A URL bookmark that may have multiple tags attached. +    """ +    url = models.URLField() +    tags = GenericRelation(Tag) + +    def __str__(self): +        return 'Bookmark: %s' % self.url + + +@python_2_unicode_compatible +class Note(models.Model): +    """ +    A textual note that may have multiple tags attached. +    """ +    text = models.TextField() +    tags = GenericRelation(Tag) + +    def __str__(self): +        return 'Note: %s' % self.text + + +class TestGenericRelations(TestCase): +    def setUp(self): +        self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') +        Tag.objects.create(tagged_item=self.bookmark, tag='django') +        Tag.objects.create(tagged_item=self.bookmark, tag='python') +        self.note = Note.objects.create(text='Remember the milk') +        Tag.objects.create(tagged_item=self.note, tag='reminder') + +    def test_generic_relation(self): +        """ +        Test a relationship that spans a GenericRelation field. +        IE. A reverse generic relationship. +        """ + +        class BookmarkSerializer(serializers.ModelSerializer): +            tags = serializers.StringRelatedField(many=True) + +            class Meta: +                model = Bookmark +                fields = ('tags', 'url') + +        serializer = BookmarkSerializer(self.bookmark) +        expected = { +            'tags': ['django', 'python'], +            'url': 'https://www.djangoproject.com/' +        } +        self.assertEqual(serializer.data, expected) + +    def test_generic_fk(self): +        """ +        Test a relationship that spans a GenericForeignKey field. +        IE. A forward generic relationship. +        """ + +        class TagSerializer(serializers.ModelSerializer): +            tagged_item = serializers.StringRelatedField() + +            class Meta: +                model = Tag +                fields = ('tag', 'tagged_item') + +        serializer = TagSerializer(Tag.objects.all(), many=True) +        expected = [ +            { +                'tag': 'django', +                'tagged_item': 'Bookmark: https://www.djangoproject.com/' +            }, +            { +                'tag': 'python', +                'tagged_item': 'Bookmark: https://www.djangoproject.com/' +            }, +            { +                'tag': 'reminder', +                'tagged_item': 'Note: Remember the milk' +            } +        ] +        self.assertEqual(serializer.data, expected) diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py new file mode 100644 index 00000000..33b09713 --- /dev/null +++ b/tests/test_relations_hyperlink.py @@ -0,0 +1,444 @@ +from __future__ import unicode_literals +from django.conf.urls import url +from django.test import TestCase +from rest_framework import serializers +from rest_framework.test import APIRequestFactory +from tests.models import ( +    ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, +    NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +) + +factory = APIRequestFactory() +request = factory.get('/')  # Just to ensure we have a request in the serializer context + + +def dummy_view(request, pk): +    pass + + +urlpatterns = [ +    url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), +    url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), +    url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), +    url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), +    url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), +    url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), +    url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), +    url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), +] + + +# ManyToMany +class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ManyToManyTarget +        fields = ('url', 'name', 'sources') + + +class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ManyToManySource +        fields = ('url', 'name', 'targets') + + +# ForeignKey +class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ForeignKeyTarget +        fields = ('url', 'name', 'sources') + + +class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ForeignKeySource +        fields = ('url', 'name', 'target') + + +# Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = NullableForeignKeySource +        fields = ('url', 'name', 'target') + + +# Nullable OneToOne +class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = OneToOneTarget +        fields = ('url', 'name', 'nullable_source') + + +# TODO: Add test that .data cannot be accessed prior to .is_valid + +class HyperlinkedManyToManyTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def setUp(self): +        for idx in range(1, 4): +            target = ManyToManyTarget(name='target-%d' % idx) +            target.save() +            source = ManyToManySource(name='source-%d' % idx) +            source.save() +            for target in ManyToManyTarget.objects.all(): +                source.targets.add(target) + +    def test_many_to_many_retrieve(self): +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, +            {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, +            {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} +        ] +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected) + +    def test_many_to_many_retrieve_prefetch_related(self): +        queryset = ManyToManySource.objects.all().prefetch_related('targets') +        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) +        with self.assertNumQueries(2): +            serializer.data + +    def test_reverse_many_to_many_retrieve(self): +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} +        ] +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected) + +    def test_many_to_many_update(self): +        data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} +        instance = ManyToManySource.objects.get(pk=1) +        serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, +            {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, +            {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_update(self): +        data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} +        instance = ManyToManyTarget.objects.get(pk=1) +        serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 1 is updated, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, +            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} + +        ] +        self.assertEqual(serializer.data, expected) + +    def test_many_to_many_create(self): +        data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} +        serializer = ManyToManySourceSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, +            {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, +            {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, +            {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_create(self): +        data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} +        serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-4') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} +        ] +        self.assertEqual(serializer.data, expected) + + +class HyperlinkedForeignKeyTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} +        ] +        with self.assertNumQueries(1): +            self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, +            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, +        ] +        with self.assertNumQueries(3): +            self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update(self): +        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, +            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_incorrect_type(self): +        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected URL string, received int.']}) + +    def test_reverse_foreign_key_update(self): +        data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        # We shouldn't have saved anything to the db yet since save +        # hasn't been called. +        queryset = ForeignKeyTarget.objects.all() +        new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, +            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(new_serializer.data, expected) + +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, +            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} +        serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} +        serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, +            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, +            {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) + + +class HyperlinkedNullableForeignKeyTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +            {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} +        expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, expected_data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +            {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} +        expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, expected_data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + + +class HyperlinkedNullableOneToOneTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, +            {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, +        ] +        self.assertEqual(serializer.data, expected) diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py new file mode 100644 index 00000000..ca43272b --- /dev/null +++ b/tests/test_relations_pk.py @@ -0,0 +1,450 @@ +from __future__ import unicode_literals +from django.test import TestCase +from django.utils import six +from rest_framework import serializers +from tests.models import ( +    ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, +    NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, +) + + +# ManyToMany +class ManyToManyTargetSerializer(serializers.ModelSerializer): +    class Meta: +        model = ManyToManyTarget +        fields = ('id', 'name', 'sources') + + +class ManyToManySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = ManyToManySource +        fields = ('id', 'name', 'targets') + + +# ForeignKey +class ForeignKeyTargetSerializer(serializers.ModelSerializer): +    class Meta: +        model = ForeignKeyTarget +        fields = ('id', 'name', 'sources') + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = ForeignKeySource +        fields = ('id', 'name', 'target') + + +# Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableForeignKeySource +        fields = ('id', 'name', 'target') + + +# Nullable OneToOne +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): +    class Meta: +        model = OneToOneTarget +        fields = ('id', 'name', 'nullable_source') + + +# TODO: Add test that .data cannot be accessed prior to .is_valid + +class PKManyToManyTests(TestCase): +    def setUp(self): +        for idx in range(1, 4): +            target = ManyToManyTarget(name='target-%d' % idx) +            target.save() +            source = ManyToManySource(name='source-%d' % idx) +            source.save() +            for target in ManyToManyTarget.objects.all(): +                source.targets.add(target) + +    def test_many_to_many_retrieve(self): +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'targets': [1]}, +            {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, +            {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} +        ] +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected) + +    def test_many_to_many_retrieve_prefetch_related(self): +        queryset = ManyToManySource.objects.all().prefetch_related('targets') +        serializer = ManyToManySourceSerializer(queryset, many=True) +        with self.assertNumQueries(2): +            serializer.data + +    def test_reverse_many_to_many_retrieve(self): +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, +            {'id': 3, 'name': 'target-3', 'sources': [3]} +        ] +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected) + +    def test_many_to_many_update(self): +        data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} +        instance = ManyToManySource.objects.get(pk=1) +        serializer = ManyToManySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, +            {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, +            {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_update(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [1]} +        instance = ManyToManyTarget.objects.get(pk=1) +        serializer = ManyToManyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 1 is updated, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1]}, +            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, +            {'id': 3, 'name': 'target-3', 'sources': [3]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_many_to_many_create(self): +        data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} +        serializer = ManyToManySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'targets': [1]}, +            {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, +            {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, +            {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_many_to_many_unsaved(self): +        source = ManyToManySource(name='source-unsaved') + +        serializer = ManyToManySourceSerializer(source) + +        expected = {'id': None, 'name': 'source-unsaved', 'targets': []} +        # no query if source hasn't been created yet +        with self.assertNumQueries(0): +            self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_create(self): +        data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} +        serializer = ManyToManyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-4') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, +            {'id': 3, 'name': 'target-3', 'sources': [3]}, +            {'id': 4, 'name': 'target-4', 'sources': [1, 3]} +        ] +        self.assertEqual(serializer.data, expected) + + +class PKForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': 1} +        ] +        with self.assertNumQueries(1): +            self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +        ] +        with self.assertNumQueries(3): +            self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve_prefetch_related(self): +        queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        with self.assertNumQueries(2): +            serializer.data + +    def test_foreign_key_update(self): +        data = {'id': 1, 'name': 'source-1', 'target': 2} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 2}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': 1} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_incorrect_type(self): +        data = {'id': 1, 'name': 'source-1', 'target': 'foo'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}) + +    def test_reverse_foreign_key_update(self): +        data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        # We shouldn't have saved anything to the db yet since save +        # hasn't been called. +        queryset = ForeignKeyTarget.objects.all() +        new_serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(new_serializer.data, expected) + +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [2]}, +            {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'id': 4, 'name': 'source-4', 'target': 2} +        serializer = ForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': 1}, +            {'id': 4, 'name': 'source-4', 'target': 2}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} +        serializer = ForeignKeyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3') + +        # Ensure target 3 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [2]}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +            {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) + +    def test_foreign_key_with_unsaved(self): +        source = ForeignKeySource(name='source-unsaved') +        expected = {'id': None, 'name': 'source-unsaved', 'target': None} + +        serializer = ForeignKeySourceSerializer(source) + +        # no query if source hasn't been created yet +        with self.assertNumQueries(0): +            self.assertEqual(serializer.data, expected) + +    def test_foreign_key_with_empty(self): +        """ +        Regression test for #1072 + +        https://github.com/tomchristie/django-rest-framework/issues/1072 +        """ +        serializer = NullableForeignKeySourceSerializer() +        self.assertEqual(serializer.data['target'], None) + + +class PKNullableForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'id': 4, 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 4, 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 4, 'name': 'source-4', 'target': ''} +        expected_data = {'id': 4, 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, expected_data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 4, 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 1, 'name': 'source-1', 'target': ''} +        expected_data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, expected_data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + + +class PKNullableOneToOneTests(TestCase): +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=new_target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'nullable_source': None}, +            {'id': 2, 'name': 'target-2', 'nullable_source': 1}, +        ] +        self.assertEqual(serializer.data, expected) diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py new file mode 100644 index 00000000..cd2cb1ed --- /dev/null +++ b/tests/test_relations_slug.py @@ -0,0 +1,281 @@ +from django.test import TestCase +from rest_framework import serializers +from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): +    sources = serializers.SlugRelatedField( +        slug_field='name', +        queryset=ForeignKeySource.objects.all(), +        many=True +    ) + +    class Meta: +        model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    target = serializers.SlugRelatedField( +        slug_field='name', +        queryset=ForeignKeyTarget.objects.all() +    ) + +    class Meta: +        model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    target = serializers.SlugRelatedField( +        slug_field='name', +        queryset=ForeignKeyTarget.objects.all(), +        allow_null=True +    ) + +    class Meta: +        model = NullableForeignKeySource + + +# TODO: M2M Tests, FKTests (Non-nullable), One2One +class SlugForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': 'target-1'} +        ] +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected) + +    def test_foreign_key_retrieve_select_related(self): +        queryset = ForeignKeySource.objects.all().select_related('target') +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        with self.assertNumQueries(1): +            serializer.data + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve_prefetch_related(self): +        queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        with self.assertNumQueries(2): +            serializer.data + +    def test_foreign_key_update(self): +        data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-2'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': 'target-1'} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_incorrect_type(self): +        data = {'id': 1, 'name': 'source-1', 'target': 123} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) + +    def test_reverse_foreign_key_update(self): +        data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        # We shouldn't have saved anything to the db yet since save +        # hasn't been called. +        queryset = ForeignKeyTarget.objects.all() +        new_serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(new_serializer.data, expected) + +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, +            {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} +        serializer = ForeignKeySourceSerializer(data=data) +        serializer.is_valid() +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': 'target-1'}, +            {'id': 4, 'name': 'source-4', 'target': 'target-2'}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} +        serializer = ForeignKeyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3') + +        # Ensure target 3 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +            {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) + + +class SlugNullableForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'id': 4, 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 4, 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 4, 'name': 'source-4', 'target': ''} +        expected_data = {'id': 4, 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, expected_data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 4, 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 1, 'name': 'source-1', 'target': ''} +        expected_data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, expected_data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) diff --git a/tests/test_renderers.py b/tests/test_renderers.py new file mode 100644 index 00000000..cb76f683 --- /dev/null +++ b/tests/test_renderers.py @@ -0,0 +1,473 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals +from django.conf.urls import patterns, url, include +from django.core.cache import cache +from django.db import models +from django.test import TestCase +from django.utils import six +from django.utils.translation import ugettext_lazy as _ +from rest_framework import status, permissions +from rest_framework.compat import OrderedDict +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework import serializers +from rest_framework.renderers import ( +    BaseRenderer, JSONRenderer, BrowsableAPIRenderer, HTMLFormRenderer +) +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from collections import MutableMapping +import json +import re + + +DUMMYSTATUS = status.HTTP_200_OK +DUMMYCONTENT = 'dummycontent' + + +def RENDERER_A_SERIALIZER(x): +    return ('Renderer A: %s' % x).encode('ascii') + + +def RENDERER_B_SERIALIZER(x): +    return ('Renderer B: %s' % x).encode('ascii') + + +expected_results = [ +    ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1,2,3]')  # Generator +] + + +class DummyTestModel(models.Model): +    name = models.CharField(max_length=42, default='') + + +class BasicRendererTests(TestCase): +    def test_expected_results(self): +        for value, renderer_cls, expected in expected_results: +            output = renderer_cls().render(value) +            self.assertEqual(output, expected) + + +class RendererA(BaseRenderer): +    media_type = 'mock/renderera' +    format = "formata" + +    def render(self, data, media_type=None, renderer_context=None): +        return RENDERER_A_SERIALIZER(data) + + +class RendererB(BaseRenderer): +    media_type = 'mock/rendererb' +    format = "formatb" + +    def render(self, data, media_type=None, renderer_context=None): +        return RENDERER_B_SERIALIZER(data) + + +class MockView(APIView): +    renderer_classes = (RendererA, RendererB) + +    def get(self, request, **kwargs): +        response = Response(DUMMYCONTENT, status=DUMMYSTATUS) +        return response + + +class MockGETView(APIView): +    def get(self, request, **kwargs): +        return Response({'foo': ['bar', 'baz']}) + + +class MockPOSTView(APIView): +    def post(self, request, **kwargs): +        return Response({'foo': request.DATA}) + + +class EmptyGETView(APIView): +    renderer_classes = (JSONRenderer,) + +    def get(self, request, **kwargs): +        return Response(status=status.HTTP_204_NO_CONTENT) + + +class HTMLView(APIView): +    renderer_classes = (BrowsableAPIRenderer, ) + +    def get(self, request, **kwargs): +        return Response('text') + + +class HTMLView1(APIView): +    renderer_classes = (BrowsableAPIRenderer, JSONRenderer) + +    def get(self, request, **kwargs): +        return Response('text') + +urlpatterns = patterns( +    '', +    url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), +    url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), +    url(r'^cache$', MockGETView.as_view()), +    url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])), +    url(r'^html$', HTMLView.as_view()), +    url(r'^html1$', HTMLView1.as_view()), +    url(r'^empty$', EmptyGETView.as_view()), +    url(r'^api', include('rest_framework.urls', namespace='rest_framework')) +) + + +class POSTDeniedPermission(permissions.BasePermission): +    def has_permission(self, request, view): +        return request.method != 'POST' + + +class POSTDeniedView(APIView): +    renderer_classes = (BrowsableAPIRenderer,) +    permission_classes = (POSTDeniedPermission,) + +    def get(self, request): +        return Response() + +    def post(self, request): +        return Response() + +    def put(self, request): +        return Response() + +    def patch(self, request): +        return Response() + + +class DocumentingRendererTests(TestCase): +    def test_only_permitted_forms_are_displayed(self): +        view = POSTDeniedView.as_view() +        request = APIRequestFactory().get('/') +        response = view(request).render() +        self.assertNotContains(response, '>POST<') +        self.assertContains(response, '>PUT<') +        self.assertContains(response, '>PATCH<') + + +class RendererEndToEndTests(TestCase): +    """ +    End-to-end testing of renderers using an RendererMixin on a generic view. +    """ + +    urls = 'tests.test_renderers' + +    def test_default_renderer_serializes_content(self): +        """If the Accept header is not set the default renderer should serialize the response.""" +        resp = self.client.get('/') +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_head_method_serializes_no_content(self): +        """No response must be included in HEAD requests.""" +        resp = self.client.head('/') +        self.assertEqual(resp.status_code, DUMMYSTATUS) +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, six.b('')) + +    def test_default_renderer_serializes_content_on_accept_any(self): +        """If the Accept header is set to */* the default renderer should serialize the response.""" +        resp = self.client.get('/', HTTP_ACCEPT='*/*') +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_default_case(self): +        """If the Accept header is set the specified renderer should serialize the response. +        (In this case we check that works for the default renderer)""" +        resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_non_default_case(self): +        """If the Accept header is set the specified renderer should serialize the response. +        (In this case we check that works for a non-default renderer)""" +        resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_accept_query(self): +        """The '_accept' query string should behave in the same way as the Accept header.""" +        param = '?%s=%s' % ( +            api_settings.URL_ACCEPT_OVERRIDE, +            RendererB.media_type +        ) +        resp = self.client.get('/' + param) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_unsatisfiable_accept_header_on_request_returns_406_status(self): +        """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" +        resp = self.client.get('/', HTTP_ACCEPT='foo/bar') +        self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) + +    def test_specified_renderer_serializes_content_on_format_query(self): +        """If a 'format' query is specified, the renderer with the matching +        format attribute should serialize the response.""" +        param = '?%s=%s' % ( +            api_settings.URL_FORMAT_OVERRIDE, +            RendererB.format +        ) +        resp = self.client.get('/' + param) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_format_kwargs(self): +        """If a 'format' keyword arg is specified, the renderer with the matching +        format attribute should serialize the response.""" +        resp = self.client.get('/something.formatb') +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): +        """If both a 'format' query and a matching Accept header specified, +        the renderer with the matching format attribute should serialize the response.""" +        param = '?%s=%s' % ( +            api_settings.URL_FORMAT_OVERRIDE, +            RendererB.format +        ) +        resp = self.client.get('/' + param, +                               HTTP_ACCEPT=RendererB.media_type) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_parse_error_renderers_browsable_api(self): +        """Invalid data should still render the browsable API correctly.""" +        resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + +    def test_204_no_content_responses_have_no_content_type_set(self): +        """ +        Regression test for #1196 + +        https://github.com/tomchristie/django-rest-framework/issues/1196 +        """ +        resp = self.client.get('/empty') +        self.assertEqual(resp.get('Content-Type', None), None) +        self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) + +    def test_contains_headers_of_api_response(self): +        """ +        Issue #1437 + +        Test we display the headers of the API response and not those from the +        HTML response +        """ +        resp = self.client.get('/html1') +        self.assertContains(resp, '>GET, HEAD, OPTIONS<') +        self.assertContains(resp, '>application/json<') +        self.assertNotContains(resp, '>text/html; charset=utf-8<') + + +_flat_repr = '{"foo":["bar","baz"]}' +_indented_repr = '{\n  "foo": [\n    "bar",\n    "baz"\n  ]\n}' + + +def strip_trailing_whitespace(content): +    """ +    Seems to be some inconsistencies re. trailing whitespace with +    different versions of the json lib. +    """ +    return re.sub(' +\n', '\n', content) + + +class JSONRendererTests(TestCase): +    """ +    Tests specific to the JSON Renderer +    """ + +    def test_render_lazy_strings(self): +        """ +        JSONRenderer should deal with lazy translated strings. +        """ +        ret = JSONRenderer().render(_('test')) +        self.assertEqual(ret, b'"test"') + +    def test_render_queryset_values(self): +        o = DummyTestModel.objects.create(name='dummy') +        qs = DummyTestModel.objects.values('id', 'name') +        ret = JSONRenderer().render(qs) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, [{'id': o.id, 'name': o.name}]) + +    def test_render_queryset_values_list(self): +        o = DummyTestModel.objects.create(name='dummy') +        qs = DummyTestModel.objects.values_list('id', 'name') +        ret = JSONRenderer().render(qs) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, [[o.id, o.name]]) + +    def test_render_dict_abc_obj(self): +        class Dict(MutableMapping): +            def __init__(self): +                self._dict = dict() + +            def __getitem__(self, key): +                return self._dict.__getitem__(key) + +            def __setitem__(self, key, value): +                return self._dict.__setitem__(key, value) + +            def __delitem__(self, key): +                return self._dict.__delitem__(key) + +            def __iter__(self): +                return self._dict.__iter__() + +            def __len__(self): +                return self._dict.__len__() + +            def keys(self): +                return self._dict.keys() + +        x = Dict() +        x['key'] = 'string value' +        x[2] = 3 +        ret = JSONRenderer().render(x) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, {'key': 'string value', '2': 3}) + +    def test_render_obj_with_getitem(self): +        class DictLike(object): +            def __init__(self): +                self._dict = {} + +            def set(self, value): +                self._dict = dict(value) + +            def __getitem__(self, key): +                return self._dict[key] + +        x = DictLike() +        x.set({'a': 1, 'b': 'string'}) +        with self.assertRaises(TypeError): +            JSONRenderer().render(x) + +    def test_without_content_type_args(self): +        """ +        Test basic JSON rendering. +        """ +        obj = {'foo': ['bar', 'baz']} +        renderer = JSONRenderer() +        content = renderer.render(obj, 'application/json') +        # Fix failing test case which depends on version of JSON library. +        self.assertEqual(content.decode('utf-8'), _flat_repr) + +    def test_with_content_type_args(self): +        """ +        Test JSON rendering with additional content type arguments supplied. +        """ +        obj = {'foo': ['bar', 'baz']} +        renderer = JSONRenderer() +        content = renderer.render(obj, 'application/json; indent=2') +        self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) + + +class UnicodeJSONRendererTests(TestCase): +    """ +    Tests specific for the Unicode JSON Renderer +    """ +    def test_proper_encoding(self): +        obj = {'countries': ['United Kingdom', 'France', 'España']} +        renderer = JSONRenderer() +        content = renderer.render(obj, 'application/json') +        self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode('utf-8')) + +    def test_u2028_u2029(self): +        # The \u2028 and \u2029 characters should be escaped, +        # even when the non-escaping unicode representation is used. +        # Regression test for #2169 +        obj = {'should_escape': '\u2028\u2029'} +        renderer = JSONRenderer() +        content = renderer.render(obj, 'application/json') +        self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode('utf-8')) + + +class AsciiJSONRendererTests(TestCase): +    """ +    Tests specific for the Unicode JSON Renderer +    """ +    def test_proper_encoding(self): +        class AsciiJSONRenderer(JSONRenderer): +            ensure_ascii = True +        obj = {'countries': ['United Kingdom', 'France', 'España']} +        renderer = AsciiJSONRenderer() +        content = renderer.render(obj, 'application/json') +        self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8')) + + +# Tests for caching issue, #346 +class CacheRenderTest(TestCase): +    """ +    Tests specific to caching responses +    """ + +    urls = 'tests.test_renderers' + +    def test_head_caching(self): +        """ +        Test caching of HEAD requests +        """ +        response = self.client.head('/cache') +        cache.set('key', response) +        cached_response = cache.get('key') +        assert isinstance(cached_response, Response) +        assert cached_response.content == response.content +        assert cached_response.status_code == response.status_code + +    def test_get_caching(self): +        """ +        Test caching of GET requests +        """ +        response = self.client.get('/cache') +        cache.set('key', response) +        cached_response = cache.get('key') +        assert isinstance(cached_response, Response) +        assert cached_response.content == response.content +        assert cached_response.status_code == response.status_code + + +class TestJSONIndentationStyles: +    def test_indented(self): +        renderer = JSONRenderer() +        data = OrderedDict([('a', 1), ('b', 2)]) +        assert renderer.render(data) == b'{"a":1,"b":2}' + +    def test_compact(self): +        renderer = JSONRenderer() +        data = OrderedDict([('a', 1), ('b', 2)]) +        context = {'indent': 4} +        assert ( +            renderer.render(data, renderer_context=context) == +            b'{\n    "a": 1,\n    "b": 2\n}' +        ) + +    def test_long_form(self): +        renderer = JSONRenderer() +        renderer.compact = False +        data = OrderedDict([('a', 1), ('b', 2)]) +        assert renderer.render(data) == b'{"a": 1, "b": 2}' + + +class TestHiddenFieldHTMLFormRenderer(TestCase): +    def test_hidden_field_rendering(self): +        class TestSerializer(serializers.Serializer): +            published = serializers.HiddenField(default=True) + +        serializer = TestSerializer(data={}) +        serializer.is_valid() +        renderer = HTMLFormRenderer() +        field = serializer['published'] +        rendered = renderer.render_field(field, {}) +        assert rendered == '' diff --git a/tests/test_request.py b/tests/test_request.py new file mode 100644 index 00000000..c274ab69 --- /dev/null +++ b/tests/test_request.py @@ -0,0 +1,278 @@ +""" +Tests for content parsing, and form-overloaded content parsing. +""" +from __future__ import unicode_literals +from django.conf.urls import patterns +from django.contrib.auth.models import User +from django.contrib.auth import authenticate, login, logout +from django.contrib.sessions.middleware import SessionMiddleware +from django.core.handlers.wsgi import WSGIRequest +from django.test import TestCase +from django.utils import six +from rest_framework import status +from rest_framework.authentication import SessionAuthentication +from rest_framework.parsers import ( +    BaseParser, +    FormParser, +    MultiPartParser, +    JSONParser +) +from rest_framework.request import Request, Empty +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory, APIClient +from rest_framework.views import APIView +from io import BytesIO +import json + + +factory = APIRequestFactory() + + +class PlainTextParser(BaseParser): +    media_type = 'text/plain' + +    def parse(self, stream, media_type=None, parser_context=None): +        """ +        Returns a 2-tuple of `(data, files)`. + +        `data` will simply be a string representing the body of the request. +        `files` will always be `None`. +        """ +        return stream.read() + + +class TestMethodOverloading(TestCase): +    def test_method(self): +        """ +        Request methods should be same as underlying request. +        """ +        request = Request(factory.get('/')) +        self.assertEqual(request.method, 'GET') +        request = Request(factory.post('/')) +        self.assertEqual(request.method, 'POST') + +    def test_overloaded_method(self): +        """ +        POST requests can be overloaded to another method by setting a +        reserved form field +        """ +        request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'})) +        self.assertEqual(request.method, 'DELETE') + +    def test_x_http_method_override_header(self): +        """ +        POST requests can also be overloaded to another method by setting +        the X-HTTP-Method-Override header. +        """ +        request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) +        self.assertEqual(request.method, 'DELETE') + +        request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) +        self.assertEqual(request.method, 'DELETE') + + +class TestContentParsing(TestCase): +    def test_standard_behaviour_determines_no_content_GET(self): +        """ +        Ensure request.DATA returns empty QueryDict for GET request. +        """ +        request = Request(factory.get('/')) +        self.assertEqual(request.DATA, {}) + +    def test_standard_behaviour_determines_no_content_HEAD(self): +        """ +        Ensure request.DATA returns empty QueryDict for HEAD request. +        """ +        request = Request(factory.head('/')) +        self.assertEqual(request.DATA, {}) + +    def test_request_DATA_with_form_content(self): +        """ +        Ensure request.DATA returns content for POST request with form content. +        """ +        data = {'qwerty': 'uiop'} +        request = Request(factory.post('/', data)) +        request.parsers = (FormParser(), MultiPartParser()) +        self.assertEqual(list(request.DATA.items()), list(data.items())) + +    def test_request_DATA_with_text_content(self): +        """ +        Ensure request.DATA returns content for POST request with +        non-form content. +        """ +        content = six.b('qwerty') +        content_type = 'text/plain' +        request = Request(factory.post('/', content, content_type=content_type)) +        request.parsers = (PlainTextParser(),) +        self.assertEqual(request.DATA, content) + +    def test_request_POST_with_form_content(self): +        """ +        Ensure request.POST returns content for POST request with form content. +        """ +        data = {'qwerty': 'uiop'} +        request = Request(factory.post('/', data)) +        request.parsers = (FormParser(), MultiPartParser()) +        self.assertEqual(list(request.POST.items()), list(data.items())) + +    def test_standard_behaviour_determines_form_content_PUT(self): +        """ +        Ensure request.DATA returns content for PUT request with form content. +        """ +        data = {'qwerty': 'uiop'} +        request = Request(factory.put('/', data)) +        request.parsers = (FormParser(), MultiPartParser()) +        self.assertEqual(list(request.DATA.items()), list(data.items())) + +    def test_standard_behaviour_determines_non_form_content_PUT(self): +        """ +        Ensure request.DATA returns content for PUT request with +        non-form content. +        """ +        content = six.b('qwerty') +        content_type = 'text/plain' +        request = Request(factory.put('/', content, content_type=content_type)) +        request.parsers = (PlainTextParser(), ) +        self.assertEqual(request.DATA, content) + +    def test_overloaded_behaviour_allows_content_tunnelling(self): +        """ +        Ensure request.DATA returns content for overloaded POST request. +        """ +        json_data = {'foobar': 'qwerty'} +        content = json.dumps(json_data) +        content_type = 'application/json' +        form_data = { +            api_settings.FORM_CONTENT_OVERRIDE: content, +            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type +        } +        request = Request(factory.post('/', form_data)) +        request.parsers = (JSONParser(), ) +        self.assertEqual(request.DATA, json_data) + +    def test_form_POST_unicode(self): +        """ +        JSON POST via default web interface with unicode data +        """ +        # Note: environ and other variables here have simplified content compared to real Request +        CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D' +        environ = { +            'REQUEST_METHOD': 'POST', +            'CONTENT_TYPE': 'application/x-www-form-urlencoded', +            'CONTENT_LENGTH': len(CONTENT), +            'wsgi.input': BytesIO(CONTENT), +        } +        wsgi_request = WSGIRequest(environ=environ) +        wsgi_request._load_post_and_files() +        parsers = (JSONParser(), FormParser(), MultiPartParser()) +        parser_context = { +            'encoding': 'utf-8', +            'kwargs': {}, +            'args': (), +        } +        request = Request(wsgi_request, parsers=parsers, parser_context=parser_context) +        method = request.method +        self.assertEqual(method, 'POST') +        self.assertEqual(request._content_type, 'application/json') +        self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}') +        self.assertEqual(request._data, Empty) +        self.assertEqual(request._files, Empty) + + +class MockView(APIView): +    authentication_classes = (SessionAuthentication,) + +    def post(self, request): +        if request.POST.get('example') is not None: +            return Response(status=status.HTTP_200_OK) + +        return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR) + +urlpatterns = patterns( +    '', +    (r'^$', MockView.as_view()), +) + + +class TestContentParsingWithAuthentication(TestCase): +    urls = 'tests.test_request' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +    def test_user_logged_in_authentication_has_POST_when_not_logged_in(self): +        """ +        Ensures request.POST exists after SessionAuthentication when user +        doesn't log in. +        """ +        content = {'example': 'example'} + +        response = self.client.post('/', content) +        self.assertEqual(status.HTTP_200_OK, response.status_code) + +        response = self.csrf_client.post('/', content) +        self.assertEqual(status.HTTP_200_OK, response.status_code) + + +class TestUserSetter(TestCase): + +    def setUp(self): +        # Pass request object through session middleware so session is +        # available to login and logout functions +        self.wrapped_request = factory.get('/') +        self.request = Request(self.wrapped_request) +        SessionMiddleware().process_request(self.request) + +        User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') +        self.user = authenticate(username='ringo', password='yellow') + +    def test_user_can_be_set(self): +        self.request.user = self.user +        self.assertEqual(self.request.user, self.user) + +    def test_user_can_login(self): +        login(self.request, self.user) +        self.assertEqual(self.request.user, self.user) + +    def test_user_can_logout(self): +        self.request.user = self.user +        self.assertFalse(self.request.user.is_anonymous()) +        logout(self.request) +        self.assertTrue(self.request.user.is_anonymous()) + +    def test_logged_in_user_is_set_on_wrapped_request(self): +        login(self.request, self.user) +        self.assertEqual(self.wrapped_request.user, self.user) + +    def test_calling_user_fails_when_attribute_error_is_raised(self): +        """ +        This proves that when an AttributeError is raised inside of the request.user +        property, that we can handle this and report the true, underlying error. +        """ +        class AuthRaisesAttributeError(object): +            def authenticate(self, request): +                import rest_framework +                rest_framework.MISSPELLED_NAME_THAT_DOESNT_EXIST + +        self.request = Request(factory.get('/'), authenticators=(AuthRaisesAttributeError(),)) +        SessionMiddleware().process_request(self.request) + +        login(self.request, self.user) +        try: +            self.request.user +        except AttributeError as error: +            self.assertEqual(str(error), "'module' object has no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'") +        else: +            assert False, 'AttributeError not raised' + + +class TestAuthSetter(TestCase): +    def test_auth_can_be_set(self): +        request = Request(factory.get('/')) +        request.auth = 'DUMMY' +        self.assertEqual(request.auth, 'DUMMY') diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 00000000..4a9deaa2 --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,292 @@ +from __future__ import unicode_literals +from django.conf.urls import patterns, url, include +from django.test import TestCase +from django.utils import six +from tests.models import BasicModel +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework import generics +from rest_framework import routers +from rest_framework import serializers +from rest_framework import status +from rest_framework.renderers import ( +    BaseRenderer, +    JSONRenderer, +    BrowsableAPIRenderer +) +from rest_framework import viewsets +from rest_framework.settings import api_settings + + +# Serializer used to test BasicModel +class BasicModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = BasicModel + + +class MockPickleRenderer(BaseRenderer): +    media_type = 'application/pickle' + + +class MockJsonRenderer(BaseRenderer): +    media_type = 'application/json' + + +class MockTextMediaRenderer(BaseRenderer): +    media_type = 'text/html' + +DUMMYSTATUS = status.HTTP_200_OK +DUMMYCONTENT = 'dummycontent' + + +def RENDERER_A_SERIALIZER(x): +    return ('Renderer A: %s' % x).encode('ascii') + + +def RENDERER_B_SERIALIZER(x): +    return ('Renderer B: %s' % x).encode('ascii') + + +class RendererA(BaseRenderer): +    media_type = 'mock/renderera' +    format = "formata" + +    def render(self, data, media_type=None, renderer_context=None): +        return RENDERER_A_SERIALIZER(data) + + +class RendererB(BaseRenderer): +    media_type = 'mock/rendererb' +    format = "formatb" + +    def render(self, data, media_type=None, renderer_context=None): +        return RENDERER_B_SERIALIZER(data) + + +class RendererC(RendererB): +    media_type = 'mock/rendererc' +    format = 'formatc' +    charset = "rendererc" + + +class MockView(APIView): +    renderer_classes = (RendererA, RendererB, RendererC) + +    def get(self, request, **kwargs): +        return Response(DUMMYCONTENT, status=DUMMYSTATUS) + + +class MockViewSettingContentType(APIView): +    renderer_classes = (RendererA, RendererB, RendererC) + +    def get(self, request, **kwargs): +        return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview') + + +class HTMLView(APIView): +    renderer_classes = (BrowsableAPIRenderer, ) + +    def get(self, request, **kwargs): +        return Response('text') + + +class HTMLView1(APIView): +    renderer_classes = (BrowsableAPIRenderer, JSONRenderer) + +    def get(self, request, **kwargs): +        return Response('text') + + +class HTMLNewModelViewSet(viewsets.ModelViewSet): +    serializer_class = BasicModelSerializer +    queryset = BasicModel.objects.all() + + +class HTMLNewModelView(generics.ListCreateAPIView): +    renderer_classes = (BrowsableAPIRenderer,) +    permission_classes = [] +    serializer_class = BasicModelSerializer +    queryset = BasicModel.objects.all() + + +new_model_viewset_router = routers.DefaultRouter() +new_model_viewset_router.register(r'', HTMLNewModelViewSet) + + +urlpatterns = patterns( +    '', +    url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])), +    url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), +    url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), +    url(r'^html$', HTMLView.as_view()), +    url(r'^html1$', HTMLView1.as_view()), +    url(r'^html_new_model$', HTMLNewModelView.as_view()), +    url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)), +    url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) +) + + +# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... +class RendererIntegrationTests(TestCase): +    """ +    End-to-end testing of renderers using an ResponseMixin on a generic view. +    """ + +    urls = 'tests.test_response' + +    def test_default_renderer_serializes_content(self): +        """If the Accept header is not set the default renderer should serialize the response.""" +        resp = self.client.get('/') +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_head_method_serializes_no_content(self): +        """No response must be included in HEAD requests.""" +        resp = self.client.head('/') +        self.assertEqual(resp.status_code, DUMMYSTATUS) +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, six.b('')) + +    def test_default_renderer_serializes_content_on_accept_any(self): +        """If the Accept header is set to */* the default renderer should serialize the response.""" +        resp = self.client.get('/', HTTP_ACCEPT='*/*') +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_default_case(self): +        """If the Accept header is set the specified renderer should serialize the response. +        (In this case we check that works for the default renderer)""" +        resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_non_default_case(self): +        """If the Accept header is set the specified renderer should serialize the response. +        (In this case we check that works for a non-default renderer)""" +        resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_accept_query(self): +        """The '_accept' query string should behave in the same way as the Accept header.""" +        param = '?%s=%s' % ( +            api_settings.URL_ACCEPT_OVERRIDE, +            RendererB.media_type +        ) +        resp = self.client.get('/' + param) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_format_query(self): +        """If a 'format' query is specified, the renderer with the matching +        format attribute should serialize the response.""" +        resp = self.client.get('/?format=%s' % RendererB.format) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_format_kwargs(self): +        """If a 'format' keyword arg is specified, the renderer with the matching +        format attribute should serialize the response.""" +        resp = self.client.get('/something.formatb') +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): +        """If both a 'format' query and a matching Accept header specified, +        the renderer with the matching format attribute should serialize the response.""" +        resp = self.client.get('/?format=%s' % RendererB.format, +                               HTTP_ACCEPT=RendererB.media_type) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + + +class Issue122Tests(TestCase): +    """ +    Tests that covers #122. +    """ +    urls = 'tests.test_response' + +    def test_only_html_renderer(self): +        """ +        Test if no infinite recursion occurs. +        """ +        self.client.get('/html') + +    def test_html_renderer_is_first(self): +        """ +        Test if no infinite recursion occurs. +        """ +        self.client.get('/html1') + + +class Issue467Tests(TestCase): +    """ +    Tests for #467 +    """ + +    urls = 'tests.test_response' + +    def test_form_has_label_and_help_text(self): +        resp = self.client.get('/html_new_model') +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        # self.assertContains(resp, 'Text comes here') +        # self.assertContains(resp, 'Text description.') + + +class Issue807Tests(TestCase): +    """ +    Covers #807 +    """ + +    urls = 'tests.test_response' + +    def test_does_not_append_charset_by_default(self): +        """ +        Renderers don't include a charset unless set explicitly. +        """ +        headers = {"HTTP_ACCEPT": RendererA.media_type} +        resp = self.client.get('/', **headers) +        expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8') +        self.assertEqual(expected, resp['Content-Type']) + +    def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): +        """ +        If renderer class has charset attribute declared, it gets appended +        to Response's Content-Type +        """ +        headers = {"HTTP_ACCEPT": RendererC.media_type} +        resp = self.client.get('/', **headers) +        expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset) +        self.assertEqual(expected, resp['Content-Type']) + +    def test_content_type_set_explicitly_on_response(self): +        """ +        The content type may be set explicitly on the response. +        """ +        headers = {"HTTP_ACCEPT": RendererC.media_type} +        resp = self.client.get('/setbyview', **headers) +        self.assertEqual('setbyview', resp['Content-Type']) + +    def test_viewset_label_help_text(self): +        param = '?%s=%s' % ( +            api_settings.URL_ACCEPT_OVERRIDE, +            'text/html' +        ) +        resp = self.client.get('/html_new_model_viewset/' + param) +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        # self.assertContains(resp, 'Text comes here') +        # self.assertContains(resp, 'Text description.') + +    def test_form_has_label_and_help_text(self): +        resp = self.client.get('/html_new_model') +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        # self.assertContains(resp, 'Text comes here') +        # self.assertContains(resp, 'Text description.') diff --git a/tests/test_reverse.py b/tests/test_reverse.py new file mode 100644 index 00000000..675a9d5a --- /dev/null +++ b/tests/test_reverse.py @@ -0,0 +1,28 @@ +from __future__ import unicode_literals +from django.conf.urls import patterns, url +from django.test import TestCase +from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory + +factory = APIRequestFactory() + + +def null_view(request): +    pass + +urlpatterns = patterns( +    '', +    url(r'^view$', null_view, name='view'), +) + + +class ReverseTests(TestCase): +    """ +    Tests for fully qualified URLs when using `reverse`. +    """ +    urls = 'tests.test_reverse' + +    def test_reversed_urls_are_fully_qualified(self): +        request = factory.get('/view') +        url = reverse('view', request=request) +        self.assertEqual(url, 'http://testserver/view') diff --git a/tests/test_routers.py b/tests/test_routers.py new file mode 100644 index 00000000..08c58ec7 --- /dev/null +++ b/tests/test_routers.py @@ -0,0 +1,348 @@ +from __future__ import unicode_literals +from django.conf.urls import url, include +from django.db import models +from django.test import TestCase +from django.core.exceptions import ImproperlyConfigured +from rest_framework import serializers, viewsets, permissions +from rest_framework.decorators import detail_route, list_route +from rest_framework.response import Response +from rest_framework.routers import SimpleRouter, DefaultRouter +from rest_framework.test import APIRequestFactory +from collections import namedtuple + +factory = APIRequestFactory() + + +class RouterTestModel(models.Model): +    uuid = models.CharField(max_length=20) +    text = models.CharField(max_length=200) + + +class NoteSerializer(serializers.HyperlinkedModelSerializer): +    url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') + +    class Meta: +        model = RouterTestModel +        fields = ('url', 'uuid', 'text') + + +class NoteViewSet(viewsets.ModelViewSet): +    queryset = RouterTestModel.objects.all() +    serializer_class = NoteSerializer +    lookup_field = 'uuid' + + +class MockViewSet(viewsets.ModelViewSet): +    queryset = None +    serializer_class = None + + +notes_router = SimpleRouter() +notes_router.register(r'notes', NoteViewSet) + +namespaced_router = DefaultRouter() +namespaced_router.register(r'example', MockViewSet, base_name='example') + +urlpatterns = [ +    url(r'^non-namespaced/', include(namespaced_router.urls)), +    url(r'^namespaced/', include(namespaced_router.urls, namespace='example')), +    url(r'^example/', include(notes_router.urls)), +] + + +class BasicViewSet(viewsets.ViewSet): +    def list(self, request, *args, **kwargs): +        return Response({'method': 'list'}) + +    @detail_route(methods=['post']) +    def action1(self, request, *args, **kwargs): +        return Response({'method': 'action1'}) + +    @detail_route(methods=['post']) +    def action2(self, request, *args, **kwargs): +        return Response({'method': 'action2'}) + +    @detail_route(methods=['post', 'delete']) +    def action3(self, request, *args, **kwargs): +        return Response({'method': 'action2'}) + +    @detail_route() +    def link1(self, request, *args, **kwargs): +        return Response({'method': 'link1'}) + +    @detail_route() +    def link2(self, request, *args, **kwargs): +        return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): +    def setUp(self): +        self.router = SimpleRouter() + +    def test_link_and_action_decorator(self): +        routes = self.router.get_routes(BasicViewSet) +        decorator_routes = routes[2:] +        # Make sure all these endpoints exist and none have been clobbered +        for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']): +            route = decorator_routes[i] +            # check url listing +            self.assertEqual(route.url, +                             '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) +            # check method to function mapping +            if endpoint == 'action3': +                methods_map = ['post', 'delete'] +            elif endpoint.startswith('action'): +                methods_map = ['post'] +            else: +                methods_map = ['get'] +            for method in methods_map: +                self.assertEqual(route.mapping[method], endpoint) + + +class TestRootView(TestCase): +    urls = 'tests.test_routers' + +    def test_retrieve_namespaced_root(self): +        response = self.client.get('/namespaced/') +        self.assertEqual( +            response.data, +            { +                "example": "http://testserver/namespaced/example/", +            } +        ) + +    def test_retrieve_non_namespaced_root(self): +        response = self.client.get('/non-namespaced/') +        self.assertEqual( +            response.data, +            { +                "example": "http://testserver/non-namespaced/example/", +            } +        ) + + +class TestCustomLookupFields(TestCase): +    """ +    Ensure that custom lookup fields are correctly routed. +    """ +    urls = 'tests.test_routers' + +    def setUp(self): +        RouterTestModel.objects.create(uuid='123', text='foo bar') + +    def test_custom_lookup_field_route(self): +        detail_route = notes_router.urls[-1] +        detail_url_pattern = detail_route.regex.pattern +        self.assertIn('<uuid>', detail_url_pattern) + +    def test_retrieve_lookup_field_list_view(self): +        response = self.client.get('/example/notes/') +        self.assertEqual( +            response.data, +            [{ +                "url": "http://testserver/example/notes/123/", +                "uuid": "123", "text": "foo bar" +            }] +        ) + +    def test_retrieve_lookup_field_detail_view(self): +        response = self.client.get('/example/notes/123/') +        self.assertEqual( +            response.data, +            { +                "url": "http://testserver/example/notes/123/", +                "uuid": "123", "text": "foo bar" +            } +        ) + + +class TestLookupValueRegex(TestCase): +    """ +    Ensure the router honors lookup_value_regex when applied +    to the viewset. +    """ +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            queryset = RouterTestModel.objects.all() +            lookup_field = 'uuid' +            lookup_value_regex = '[0-9a-f]{32}' + +        self.router = SimpleRouter() +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_urls_limited_by_lookup_value_regex(self): +        expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$'] +        for idx in range(len(expected)): +            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestTrailingSlashIncluded(TestCase): +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            queryset = RouterTestModel.objects.all() + +        self.router = SimpleRouter() +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_urls_have_trailing_slash_by_default(self): +        expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$'] +        for idx in range(len(expected)): +            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestTrailingSlashRemoved(TestCase): +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            queryset = RouterTestModel.objects.all() + +        self.router = SimpleRouter(trailing_slash=False) +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_urls_can_have_trailing_slash_removed(self): +        expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$'] +        for idx in range(len(expected)): +            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestNameableRoot(TestCase): +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            queryset = RouterTestModel.objects.all() + +        self.router = DefaultRouter() +        self.router.root_view_name = 'nameable-root' +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_router_has_custom_name(self): +        expected = 'nameable-root' +        self.assertEqual(expected, self.urls[0].name) + + +class TestActionKeywordArgs(TestCase): +    """ +    Ensure keyword arguments passed in the `@action` decorator +    are properly handled.  Refs #940. +    """ + +    def setUp(self): +        class TestViewSet(viewsets.ModelViewSet): +            permission_classes = [] + +            @detail_route(methods=['post'], permission_classes=[permissions.AllowAny]) +            def custom(self, request, *args, **kwargs): +                return Response({ +                    'permission_classes': self.permission_classes +                }) + +        self.router = SimpleRouter() +        self.router.register(r'test', TestViewSet, base_name='test') +        self.view = self.router.urls[-1].callback + +    def test_action_kwargs(self): +        request = factory.post('/test/0/custom/') +        response = self.view(request) +        self.assertEqual( +            response.data, +            {'permission_classes': [permissions.AllowAny]} +        ) + + +class TestActionAppliedToExistingRoute(TestCase): +    """ +    Ensure `@detail_route` decorator raises an except when applied +    to an existing route +    """ + +    def test_exception_raised_when_action_applied_to_existing_route(self): +        class TestViewSet(viewsets.ModelViewSet): + +            @detail_route(methods=['post']) +            def retrieve(self, request, *args, **kwargs): +                return Response({ +                    'hello': 'world' +                }) + +        self.router = SimpleRouter() +        self.router.register(r'test', TestViewSet, base_name='test') + +        with self.assertRaises(ImproperlyConfigured): +            self.router.urls + + +class DynamicListAndDetailViewSet(viewsets.ViewSet): +    def list(self, request, *args, **kwargs): +        return Response({'method': 'list'}) + +    @list_route(methods=['post']) +    def list_route_post(self, request, *args, **kwargs): +        return Response({'method': 'action1'}) + +    @detail_route(methods=['post']) +    def detail_route_post(self, request, *args, **kwargs): +        return Response({'method': 'action2'}) + +    @list_route() +    def list_route_get(self, request, *args, **kwargs): +        return Response({'method': 'link1'}) + +    @detail_route() +    def detail_route_get(self, request, *args, **kwargs): +        return Response({'method': 'link2'}) + +    @list_route(url_path="list_custom-route") +    def list_custom_route_get(self, request, *args, **kwargs): +        return Response({'method': 'link1'}) + +    @detail_route(url_path="detail_custom-route") +    def detail_custom_route_get(self, request, *args, **kwargs): +        return Response({'method': 'link2'}) + + +class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet): +    pass + + +class TestDynamicListAndDetailRouter(TestCase): +    def setUp(self): +        self.router = SimpleRouter() + +    def _test_list_and_detail_route_decorators(self, viewset): +        routes = self.router.get_routes(viewset) +        decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] + +        MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path') +        # Make sure all these endpoints exist and none have been clobbered +        for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), +                                      MethodNamesMap('list_route_get', 'list_route_get'), +                                      MethodNamesMap('list_route_post', 'list_route_post'), +                                      MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), +                                      MethodNamesMap('detail_route_get', 'detail_route_get'), +                                      MethodNamesMap('detail_route_post', 'detail_route_post') +                                      ]): +            route = decorator_routes[i] +            # check url listing +            method_name = endpoint.method_name +            url_path = endpoint.url_path + +            if method_name.startswith('list_'): +                self.assertEqual(route.url, +                                 '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)) +            else: +                self.assertEqual(route.url, +                                 '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)) +            # check method to function mapping +            if method_name.endswith('_post'): +                method_map = 'post' +            else: +                method_map = 'get' +            self.assertEqual(route.mapping[method_map], method_name) + +    def test_list_and_detail_route_decorators(self): +        self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet) + +    def test_inherited_list_and_detail_route_decorators(self): +        self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet) diff --git a/tests/test_serializer.py b/tests/test_serializer.py new file mode 100644 index 00000000..b7a0484b --- /dev/null +++ b/tests/test_serializer.py @@ -0,0 +1,297 @@ +# coding: utf-8 +from __future__ import unicode_literals +from .utils import MockObject +from rest_framework import serializers +from rest_framework.compat import unicode_repr +import pickle +import pytest + + +# Tests for core functionality. +# ----------------------------- + +class TestSerializer: +    def setup(self): +        class ExampleSerializer(serializers.Serializer): +            char = serializers.CharField() +            integer = serializers.IntegerField() +        self.Serializer = ExampleSerializer + +    def test_valid_serializer(self): +        serializer = self.Serializer(data={'char': 'abc', 'integer': 123}) +        assert serializer.is_valid() +        assert serializer.validated_data == {'char': 'abc', 'integer': 123} +        assert serializer.errors == {} + +    def test_invalid_serializer(self): +        serializer = self.Serializer(data={'char': 'abc'}) +        assert not serializer.is_valid() +        assert serializer.validated_data == {} +        assert serializer.errors == {'integer': ['This field is required.']} + +    def test_partial_validation(self): +        serializer = self.Serializer(data={'char': 'abc'}, partial=True) +        assert serializer.is_valid() +        assert serializer.validated_data == {'char': 'abc'} +        assert serializer.errors == {} + +    def test_empty_serializer(self): +        serializer = self.Serializer() +        assert serializer.data == {'char': '', 'integer': None} + +    def test_missing_attribute_during_serialization(self): +        class MissingAttributes: +            pass +        instance = MissingAttributes() +        serializer = self.Serializer(instance) +        with pytest.raises(AttributeError): +            serializer.data + + +class TestValidateMethod: +    def test_non_field_error_validate_method(self): +        class ExampleSerializer(serializers.Serializer): +            char = serializers.CharField() +            integer = serializers.IntegerField() + +            def validate(self, attrs): +                raise serializers.ValidationError('Non field error') + +        serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) +        assert not serializer.is_valid() +        assert serializer.errors == {'non_field_errors': ['Non field error']} + +    def test_field_error_validate_method(self): +        class ExampleSerializer(serializers.Serializer): +            char = serializers.CharField() +            integer = serializers.IntegerField() + +            def validate(self, attrs): +                raise serializers.ValidationError({'char': 'Field error'}) + +        serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) +        assert not serializer.is_valid() +        assert serializer.errors == {'char': ['Field error']} + + +class TestBaseSerializer: +    def setup(self): +        class ExampleSerializer(serializers.BaseSerializer): +            def to_representation(self, obj): +                return { +                    'id': obj['id'], +                    'email': obj['name'] + '@' + obj['domain'] +                } + +            def to_internal_value(self, data): +                name, domain = str(data['email']).split('@') +                return { +                    'id': int(data['id']), +                    'name': name, +                    'domain': domain, +                } + +        self.Serializer = ExampleSerializer + +    def test_serialize_instance(self): +        instance = {'id': 1, 'name': 'tom', 'domain': 'example.com'} +        serializer = self.Serializer(instance) +        assert serializer.data == {'id': 1, 'email': 'tom@example.com'} + +    def test_serialize_list(self): +        instances = [ +            {'id': 1, 'name': 'tom', 'domain': 'example.com'}, +            {'id': 2, 'name': 'ann', 'domain': 'example.com'}, +        ] +        serializer = self.Serializer(instances, many=True) +        assert serializer.data == [ +            {'id': 1, 'email': 'tom@example.com'}, +            {'id': 2, 'email': 'ann@example.com'} +        ] + +    def test_validate_data(self): +        data = {'id': 1, 'email': 'tom@example.com'} +        serializer = self.Serializer(data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'id': 1, +            'name': 'tom', +            'domain': 'example.com' +        } + +    def test_validate_list(self): +        data = [ +            {'id': 1, 'email': 'tom@example.com'}, +            {'id': 2, 'email': 'ann@example.com'}, +        ] +        serializer = self.Serializer(data=data, many=True) +        assert serializer.is_valid() +        assert serializer.validated_data == [ +            {'id': 1, 'name': 'tom', 'domain': 'example.com'}, +            {'id': 2, 'name': 'ann', 'domain': 'example.com'} +        ] + + +class TestStarredSource: +    """ +    Tests for `source='*'` argument, which is used for nested representations. + +    For example: + +        nested_field = NestedField(source='*') +    """ +    data = { +        'nested1': {'a': 1, 'b': 2}, +        'nested2': {'c': 3, 'd': 4} +    } + +    def setup(self): +        class NestedSerializer1(serializers.Serializer): +            a = serializers.IntegerField() +            b = serializers.IntegerField() + +        class NestedSerializer2(serializers.Serializer): +            c = serializers.IntegerField() +            d = serializers.IntegerField() + +        class TestSerializer(serializers.Serializer): +            nested1 = NestedSerializer1(source='*') +            nested2 = NestedSerializer2(source='*') + +        self.Serializer = TestSerializer + +    def test_nested_validate(self): +        """ +        A nested representation is validated into a flat internal object. +        """ +        serializer = self.Serializer(data=self.data) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'a': 1, +            'b': 2, +            'c': 3, +            'd': 4 +        } + +    def test_nested_serialize(self): +        """ +        An object can be serialized into a nested representation. +        """ +        instance = {'a': 1, 'b': 2, 'c': 3, 'd': 4} +        serializer = self.Serializer(instance) +        assert serializer.data == self.data + + +class TestIncorrectlyConfigured: +    def test_incorrect_field_name(self): +        class ExampleSerializer(serializers.Serializer): +            incorrect_name = serializers.IntegerField() + +        class ExampleObject: +            def __init__(self): +                self.correct_name = 123 + +        instance = ExampleObject() +        serializer = ExampleSerializer(instance) +        with pytest.raises(AttributeError) as exc_info: +            serializer.data +        msg = str(exc_info.value) +        assert msg.startswith( +            "Got AttributeError when attempting to get a value for field `incorrect_name` on serializer `ExampleSerializer`.\n" +            "The serializer field might be named incorrectly and not match any attribute or key on the `ExampleObject` instance.\n" +            "Original exception text was:" +        ) + + +class TestUnicodeRepr: +    def test_unicode_repr(self): +        class ExampleSerializer(serializers.Serializer): +            example = serializers.CharField() + +        class ExampleObject: +            def __init__(self): +                self.example = '한국' + +            def __repr__(self): +                return unicode_repr(self.example) + +        instance = ExampleObject() +        serializer = ExampleSerializer(instance) +        repr(serializer)  # Should not error. + + +class TestNotRequiredOutput: +    def test_not_required_output_for_dict(self): +        """ +        'required=False' should allow a dictionary key to be missing in output. +        """ +        class ExampleSerializer(serializers.Serializer): +            omitted = serializers.CharField(required=False) +            included = serializers.CharField() + +        serializer = ExampleSerializer(data={'included': 'abc'}) +        serializer.is_valid() +        assert serializer.data == {'included': 'abc'} + +    def test_not_required_output_for_object(self): +        """ +        'required=False' should allow an object attribute to be missing in output. +        """ +        class ExampleSerializer(serializers.Serializer): +            omitted = serializers.CharField(required=False) +            included = serializers.CharField() + +            def create(self, validated_data): +                return MockObject(**validated_data) + +        serializer = ExampleSerializer(data={'included': 'abc'}) +        serializer.is_valid() +        serializer.save() +        assert serializer.data == {'included': 'abc'} + +    def test_default_required_output_for_dict(self): +        """ +        'default="something"' should require dictionary key. + +        We need to handle this as the field will have an implicit +        'required=False', but it should still have a value. +        """ +        class ExampleSerializer(serializers.Serializer): +            omitted = serializers.CharField(default='abc') +            included = serializers.CharField() + +        serializer = ExampleSerializer({'included': 'abc'}) +        with pytest.raises(KeyError): +            serializer.data + +    def test_default_required_output_for_object(self): +        """ +        'default="something"' should require object attribute. + +        We need to handle this as the field will have an implicit +        'required=False', but it should still have a value. +        """ +        class ExampleSerializer(serializers.Serializer): +            omitted = serializers.CharField(default='abc') +            included = serializers.CharField() + +        instance = MockObject(included='abc') +        serializer = ExampleSerializer(instance) +        with pytest.raises(AttributeError): +            serializer.data + + +class TestCacheSerializerData: +    def test_cache_serializer_data(self): +        """ +        Caching serializer data with pickle will drop the serializer info, +        but does preserve the data itself. +        """ +        class ExampleSerializer(serializers.Serializer): +            field1 = serializers.CharField() +            field2 = serializers.CharField() + +        serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'}) +        pickled = pickle.dumps(serializer.data) +        data = pickle.loads(pickled) +        assert data == {'field1': 'a', 'field2': 'b'} diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py new file mode 100644 index 00000000..bc955b2e --- /dev/null +++ b/tests/test_serializer_bulk_update.py @@ -0,0 +1,123 @@ +""" +Tests to cover bulk create and update using serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from django.utils import six +from rest_framework import serializers + + +class BulkCreateSerializerTests(TestCase): +    """ +    Creating multiple instances using serializers. +    """ + +    def setUp(self): +        class BookSerializer(serializers.Serializer): +            id = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            author = serializers.CharField(max_length=100) + +        self.BookSerializer = BookSerializer + +    def test_bulk_create_success(self): +        """ +        Correct bulk update serialization should return the input data. +        """ + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 2, +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] + +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.validated_data, data) + +    def test_bulk_create_errors(self): +        """ +        Incorrect bulk create serialization should return errors. +        """ + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 'foo', +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] +        expected_errors = [ +            {}, +            {}, +            {'id': ['A valid integer is required.']} +        ] + +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_list_datatype(self): +        """ +        Data containing list of incorrect data type should return errors. +        """ +        data = ['foo', 'bar', 'baz'] +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        text_type_string = six.text_type.__name__ +        message = 'Invalid data. Expected a dictionary, but got %s.' % text_type_string +        expected_errors = [ +            {'non_field_errors': [message]}, +            {'non_field_errors': [message]}, +            {'non_field_errors': [message]} +        ] + +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_single_datatype(self): +        """ +        Data containing a single incorrect data type should return errors. +        """ +        data = 123 +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} + +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_single_object(self): +        """ +        Data containing only a single object, instead of a list of objects +        should return errors. +        """ +        data = { +            'id': 0, +            'title': 'The electric kool-aid acid test', +            'author': 'Tom Wolfe' +        } +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} + +        self.assertEqual(serializer.errors, expected_errors) diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py new file mode 100644 index 00000000..35b68ae7 --- /dev/null +++ b/tests/test_serializer_lists.py @@ -0,0 +1,290 @@ +from rest_framework import serializers +from django.utils.datastructures import MultiValueDict + + +class BasicObject: +    """ +    A mock object for testing serializer save behavior. +    """ +    def __init__(self, **kwargs): +        self._data = kwargs +        for key, value in kwargs.items(): +            setattr(self, key, value) + +    def __eq__(self, other): +        if self._data.keys() != other._data.keys(): +            return False +        for key in self._data.keys(): +            if self._data[key] != other._data[key]: +                return False +        return True + + +class TestListSerializer: +    """ +    Tests for using a ListSerializer as a top-level serializer. +    Note that this is in contrast to using ListSerializer as a field. +    """ + +    def setup(self): +        class IntegerListSerializer(serializers.ListSerializer): +            child = serializers.IntegerField() +        self.Serializer = IntegerListSerializer + +    def test_validate(self): +        """ +        Validating a list of items should return a list of validated items. +        """ +        input_data = ["123", "456"] +        expected_output = [123, 456] +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_output + +    def test_validate_html_input(self): +        """ +        HTML input should be able to mock list structures using [x] style ids. +        """ +        input_data = MultiValueDict({"[0]": ["123"], "[1]": ["456"]}) +        expected_output = [123, 456] +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_output + + +class TestListSerializerContainingNestedSerializer: +    """ +    Tests for using a ListSerializer containing another serializer. +    """ + +    def setup(self): +        class TestSerializer(serializers.Serializer): +            integer = serializers.IntegerField() +            boolean = serializers.BooleanField() + +            def create(self, validated_data): +                return BasicObject(**validated_data) + +        class ObjectListSerializer(serializers.ListSerializer): +            child = TestSerializer() + +        self.Serializer = ObjectListSerializer + +    def test_validate(self): +        """ +        Validating a list of dictionaries should return a list of +        validated dictionaries. +        """ +        input_data = [ +            {"integer": "123", "boolean": "true"}, +            {"integer": "456", "boolean": "false"} +        ] +        expected_output = [ +            {"integer": 123, "boolean": True}, +            {"integer": 456, "boolean": False} +        ] +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_output + +    def test_create(self): +        """ +        Creating from a list of dictionaries should return a list of objects. +        """ +        input_data = [ +            {"integer": "123", "boolean": "true"}, +            {"integer": "456", "boolean": "false"} +        ] +        expected_output = [ +            BasicObject(integer=123, boolean=True), +            BasicObject(integer=456, boolean=False), +        ] +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.save() == expected_output + +    def test_serialize(self): +        """ +        Serialization of a list of objects should return a list of dictionaries. +        """ +        input_objects = [ +            BasicObject(integer=123, boolean=True), +            BasicObject(integer=456, boolean=False) +        ] +        expected_output = [ +            {"integer": 123, "boolean": True}, +            {"integer": 456, "boolean": False} +        ] +        serializer = self.Serializer(input_objects) +        assert serializer.data == expected_output + +    def test_validate_html_input(self): +        """ +        HTML input should be able to mock list structures using [x] +        style prefixes. +        """ +        input_data = MultiValueDict({ +            "[0]integer": ["123"], +            "[0]boolean": ["true"], +            "[1]integer": ["456"], +            "[1]boolean": ["false"] +        }) +        expected_output = [ +            {"integer": 123, "boolean": True}, +            {"integer": 456, "boolean": False} +        ] +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_output + + +class TestNestedListSerializer: +    """ +    Tests for using a ListSerializer as a field. +    """ + +    def setup(self): +        class TestSerializer(serializers.Serializer): +            integers = serializers.ListSerializer(child=serializers.IntegerField()) +            booleans = serializers.ListSerializer(child=serializers.BooleanField()) + +            def create(self, validated_data): +                return BasicObject(**validated_data) + +        self.Serializer = TestSerializer + +    def test_validate(self): +        """ +        Validating a list of items should return a list of validated items. +        """ +        input_data = { +            "integers": ["123", "456"], +            "booleans": ["true", "false"] +        } +        expected_output = { +            "integers": [123, 456], +            "booleans": [True, False] +        } +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_output + +    def test_create(self): +        """ +        Creation with a list of items return an object with an attribute that +        is a list of items. +        """ +        input_data = { +            "integers": ["123", "456"], +            "booleans": ["true", "false"] +        } +        expected_output = BasicObject( +            integers=[123, 456], +            booleans=[True, False] +        ) +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.save() == expected_output + +    def test_serialize(self): +        """ +        Serialization of a list of items should return a list of items. +        """ +        input_object = BasicObject( +            integers=[123, 456], +            booleans=[True, False] +        ) +        expected_output = { +            "integers": [123, 456], +            "booleans": [True, False] +        } +        serializer = self.Serializer(input_object) +        assert serializer.data == expected_output + +    def test_validate_html_input(self): +        """ +        HTML input should be able to mock list structures using [x] +        style prefixes. +        """ +        input_data = MultiValueDict({ +            "integers[0]": ["123"], +            "integers[1]": ["456"], +            "booleans[0]": ["true"], +            "booleans[1]": ["false"] +        }) +        expected_output = { +            "integers": [123, 456], +            "booleans": [True, False] +        } +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_output + + +class TestNestedListOfListsSerializer: +    def setup(self): +        class TestSerializer(serializers.Serializer): +            integers = serializers.ListSerializer( +                child=serializers.ListSerializer( +                    child=serializers.IntegerField() +                ) +            ) +            booleans = serializers.ListSerializer( +                child=serializers.ListSerializer( +                    child=serializers.BooleanField() +                ) +            ) + +        self.Serializer = TestSerializer + +    def test_validate(self): +        input_data = { +            'integers': [['123', '456'], ['789', '0']], +            'booleans': [['true', 'true'], ['false', 'true']] +        } +        expected_output = { +            "integers": [[123, 456], [789, 0]], +            "booleans": [[True, True], [False, True]] +        } +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_output + +    def test_validate_html_input(self): +        """ +        HTML input should be able to mock lists of lists using [x][y] +        style prefixes. +        """ +        input_data = MultiValueDict({ +            "integers[0][0]": ["123"], +            "integers[0][1]": ["456"], +            "integers[1][0]": ["789"], +            "integers[1][1]": ["000"], +            "booleans[0][0]": ["true"], +            "booleans[0][1]": ["true"], +            "booleans[1][0]": ["false"], +            "booleans[1][1]": ["true"] +        }) +        expected_output = { +            "integers": [[123, 456], [789, 0]], +            "booleans": [[True, True], [False, True]] +        } +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_output + + +class TestListSerializerClass: +    """Tests for a custom list_serializer_class.""" +    def test_list_serializer_class_validate(self): +        class CustomListSerializer(serializers.ListSerializer): +            def validate(self, attrs): +                raise serializers.ValidationError('Non field error') + +        class TestSerializer(serializers.Serializer): +            class Meta: +                list_serializer_class = CustomListSerializer + +        serializer = TestSerializer(data=[], many=True) +        assert not serializer.is_valid() +        assert serializer.errors == {'non_field_errors': ['Non field error']} diff --git a/tests/test_serializer_nested.py b/tests/test_serializer_nested.py new file mode 100644 index 00000000..f5e4b26a --- /dev/null +++ b/tests/test_serializer_nested.py @@ -0,0 +1,40 @@ +from rest_framework import serializers + + +class TestNestedSerializer: +    def setup(self): +        class NestedSerializer(serializers.Serializer): +            one = serializers.IntegerField(max_value=10) +            two = serializers.IntegerField(max_value=10) + +        class TestSerializer(serializers.Serializer): +            nested = NestedSerializer() + +        self.Serializer = TestSerializer + +    def test_nested_validate(self): +        input_data = { +            'nested': { +                'one': '1', +                'two': '2', +            } +        } +        expected_data = { +            'nested': { +                'one': 1, +                'two': 2, +            } +        } +        serializer = self.Serializer(data=input_data) +        assert serializer.is_valid() +        assert serializer.validated_data == expected_data + +    def test_nested_serialize_empty(self): +        expected_data = { +            'nested': { +                'one': None, +                'two': None +            } +        } +        serializer = self.Serializer() +        assert serializer.data == expected_data diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 00000000..f2ff4ca1 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,17 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.settings import APISettings + + +class TestSettings(TestCase): +    def test_import_error_message_maintained(self): +        """ +        Make sure import errors are captured and raised sensibly. +        """ +        settings = APISettings({ +            'DEFAULT_RENDERER_CLASSES': [ +                'tests.invalid_module.InvalidClassName' +            ] +        }) +        with self.assertRaises(ImportError): +            settings.DEFAULT_RENDERER_CLASSES diff --git a/tests/test_status.py b/tests/test_status.py new file mode 100644 index 00000000..721a6e30 --- /dev/null +++ b/tests/test_status.py @@ -0,0 +1,33 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.status import ( +    is_informational, is_success, is_redirect, is_client_error, is_server_error +) + + +class TestStatus(TestCase): +    def test_status_categories(self): +        self.assertFalse(is_informational(99)) +        self.assertTrue(is_informational(100)) +        self.assertTrue(is_informational(199)) +        self.assertFalse(is_informational(200)) + +        self.assertFalse(is_success(199)) +        self.assertTrue(is_success(200)) +        self.assertTrue(is_success(299)) +        self.assertFalse(is_success(300)) + +        self.assertFalse(is_redirect(299)) +        self.assertTrue(is_redirect(300)) +        self.assertTrue(is_redirect(399)) +        self.assertFalse(is_redirect(400)) + +        self.assertFalse(is_client_error(399)) +        self.assertTrue(is_client_error(400)) +        self.assertTrue(is_client_error(499)) +        self.assertFalse(is_client_error(500)) + +        self.assertFalse(is_server_error(499)) +        self.assertTrue(is_server_error(500)) +        self.assertTrue(is_server_error(599)) +        self.assertFalse(is_server_error(600)) diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py new file mode 100644 index 00000000..0cee91f1 --- /dev/null +++ b/tests/test_templatetags.py @@ -0,0 +1,75 @@ +# encoding: utf-8 +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.test import APIRequestFactory +from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links + + +factory = APIRequestFactory() + + +class TemplateTagTests(TestCase): + +    def test_add_query_param_with_non_latin_charactor(self): +        # Ensure we don't double-escape non-latin characters +        # that are present in the querystring. +        # See #1314. +        request = factory.get("/", {'q': '查询'}) +        json_url = add_query_param(request, "format", "json") +        self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url) +        self.assertIn("format=json", json_url) + + +class Issue1386Tests(TestCase): +    """ +    Covers #1386 +    """ + +    def test_issue_1386(self): +        """ +        Test function urlize_quoted_links with different args +        """ +        correct_urls = [ +            "asdf.com", +            "asdf.net", +            "www.as_df.org", +            "as.d8f.ghj8.gov", +        ] +        for i in correct_urls: +            res = urlize_quoted_links(i) +            self.assertNotEqual(res, i) +            self.assertIn(i, res) + +        incorrect_urls = [ +            "mailto://asdf@fdf.com", +            "asdf.netnet", +        ] +        for i in incorrect_urls: +            res = urlize_quoted_links(i) +            self.assertEqual(i, res) + +        # example from issue #1386, this shouldn't raise an exception +        urlize_quoted_links("asdf:[/p]zxcv.com") + + +class URLizerTests(TestCase): +    """ +    Test if JSON URLs are transformed into links well +    """ +    def _urlize_dict_check(self, data): +        """ +        For all items in dict test assert that the value is urlized key +        """ +        for original, urlized in data.items(): +            assert urlize_quoted_links(original, nofollow=False) == urlized + +    def test_json_with_url(self): +        """ +        Test if JSON URLs are transformed into links well +        """ +        data = {} +        data['"url": "http://api/users/1/", '] = \ +            '"url": "<a href="http://api/users/1/">http://api/users/1/</a>", ' +        data['"foo_set": [\n    "http://api/foos/1/"\n], '] = \ +            '"foo_set": [\n    "<a href="http://api/foos/1/">http://api/foos/1/</a>"\n], ' +        self._urlize_dict_check(data) diff --git a/tests/test_testing.py b/tests/test_testing.py new file mode 100644 index 00000000..87d2b61f --- /dev/null +++ b/tests/test_testing.py @@ -0,0 +1,234 @@ +# encoding: utf-8 +from __future__ import unicode_literals +from django.conf.urls import patterns, url +from django.contrib.auth.models import User +from django.shortcuts import redirect +from django.test import TestCase +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.test import APIClient, APIRequestFactory, force_authenticate +from io import BytesIO + + +@api_view(['GET', 'POST']) +def view(request): +    return Response({ +        'auth': request.META.get('HTTP_AUTHORIZATION', b''), +        'user': request.user.username +    }) + + +@api_view(['GET', 'POST']) +def session_view(request): +    active_session = request.session.get('active_session', False) +    request.session['active_session'] = True +    return Response({ +        'active_session': active_session +    }) + + +@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) +def redirect_view(request): +    return redirect('/view/') + + +urlpatterns = patterns( +    '', +    url(r'^view/$', view), +    url(r'^session-view/$', session_view), +    url(r'^redirect-view/$', redirect_view), +) + + +class TestAPITestClient(TestCase): +    urls = 'tests.test_testing' + +    def setUp(self): +        self.client = APIClient() + +    def test_credentials(self): +        """ +        Setting `.credentials()` adds the required headers to each request. +        """ +        self.client.credentials(HTTP_AUTHORIZATION='example') +        for _ in range(0, 3): +            response = self.client.get('/view/') +            self.assertEqual(response.data['auth'], 'example') + +    def test_force_authenticate(self): +        """ +        Setting `.force_authenticate()` forcibly authenticates each request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        self.client.force_authenticate(user) +        response = self.client.get('/view/') +        self.assertEqual(response.data['user'], 'example') + +    def test_force_authenticate_with_sessions(self): +        """ +        Setting `.force_authenticate()` forcibly authenticates each request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        self.client.force_authenticate(user) + +        # First request does not yet have an active session +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], False) + +        # Subsequant requests have an active session +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], True) + +        # Force authenticating as `None` should also logout the user session. +        self.client.force_authenticate(None) +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], False) + +    def test_csrf_exempt_by_default(self): +        """ +        By default, the test client is CSRF exempt. +        """ +        User.objects.create_user('example', 'example@example.com', 'password') +        self.client.login(username='example', password='password') +        response = self.client.post('/view/') +        self.assertEqual(response.status_code, 200) + +    def test_explicitly_enforce_csrf_checks(self): +        """ +        The test client can enforce CSRF checks. +        """ +        client = APIClient(enforce_csrf_checks=True) +        User.objects.create_user('example', 'example@example.com', 'password') +        client.login(username='example', password='password') +        response = client.post('/view/') +        expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} +        self.assertEqual(response.status_code, 403) +        self.assertEqual(response.data, expected) + +    def test_can_logout(self): +        """ +        `logout()` resets stored credentials +        """ +        self.client.credentials(HTTP_AUTHORIZATION='example') +        response = self.client.get('/view/') +        self.assertEqual(response.data['auth'], 'example') +        self.client.logout() +        response = self.client.get('/view/') +        self.assertEqual(response.data['auth'], b'') + +    def test_logout_resets_force_authenticate(self): +        """ +        `logout()` resets any `force_authenticate` +        """ +        user = User.objects.create_user('example', 'example@example.com', 'password') +        self.client.force_authenticate(user) +        response = self.client.get('/view/') +        self.assertEqual(response.data['user'], 'example') +        self.client.logout() +        response = self.client.get('/view/') +        self.assertEqual(response.data['user'], '') + +    def test_follow_redirect(self): +        """ +        Follow redirect by setting follow argument. +        """ +        response = self.client.get('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.get('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.post('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.post('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.put('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.put('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.patch('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.patch('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.delete('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.delete('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.options('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.options('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + + +class TestAPIRequestFactory(TestCase): +    def test_csrf_exempt_by_default(self): +        """ +        By default, the test client is CSRF exempt. +        """ +        user = User.objects.create_user('example', 'example@example.com', 'password') +        factory = APIRequestFactory() +        request = factory.post('/view/') +        request.user = user +        response = view(request) +        self.assertEqual(response.status_code, 200) + +    def test_explicitly_enforce_csrf_checks(self): +        """ +        The test client can enforce CSRF checks. +        """ +        user = User.objects.create_user('example', 'example@example.com', 'password') +        factory = APIRequestFactory(enforce_csrf_checks=True) +        request = factory.post('/view/') +        request.user = user +        response = view(request) +        expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} +        self.assertEqual(response.status_code, 403) +        self.assertEqual(response.data, expected) + +    def test_invalid_format(self): +        """ +        Attempting to use a format that is not configured will raise an +        assertion error. +        """ +        factory = APIRequestFactory() +        self.assertRaises( +            AssertionError, factory.post, +            path='/view/', data={'example': 1}, format='xml' +        ) + +    def test_force_authenticate(self): +        """ +        Setting `force_authenticate()` forcibly authenticates the request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        factory = APIRequestFactory() +        request = factory.get('/view') +        force_authenticate(request, user=user) +        response = view(request) +        self.assertEqual(response.data['user'], 'example') + +    def test_upload_file(self): +        # This is a 1x1 black png +        simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82') +        simple_png.name = 'test.png' +        factory = APIRequestFactory() +        factory.post('/', data={'image': simple_png}) + +    def test_request_factory_url_arguments(self): +        """ +        This is a non regression test against #1461 +        """ +        factory = APIRequestFactory() +        request = factory.get('/view/?demo=test') +        self.assertEqual(dict(request.GET), {'demo': ['test']}) +        request = factory.get('/view/', {'demo': 'test'}) +        self.assertEqual(dict(request.GET), {'demo': ['test']}) diff --git a/tests/test_throttling.py b/tests/test_throttling.py new file mode 100644 index 00000000..50a53b3e --- /dev/null +++ b/tests/test_throttling.py @@ -0,0 +1,353 @@ +""" +Tests for the throttling implementations in the permissions module. +""" +from __future__ import unicode_literals +from django.test import TestCase +from django.contrib.auth.models import User +from django.core.cache import cache +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView +from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle +from rest_framework.response import Response + + +class User3SecRateThrottle(UserRateThrottle): +    rate = '3/sec' +    scope = 'seconds' + + +class User3MinRateThrottle(UserRateThrottle): +    rate = '3/min' +    scope = 'minutes' + + +class NonTimeThrottle(BaseThrottle): +    def allow_request(self, request, view): +        if not hasattr(self.__class__, 'called'): +            self.__class__.called = True +            return True +        return False + + +class MockView(APIView): +    throttle_classes = (User3SecRateThrottle,) + +    def get(self, request): +        return Response('foo') + + +class MockView_MinuteThrottling(APIView): +    throttle_classes = (User3MinRateThrottle,) + +    def get(self, request): +        return Response('foo') + + +class MockView_NonTimeThrottling(APIView): +    throttle_classes = (NonTimeThrottle,) + +    def get(self, request): +        return Response('foo') + + +class ThrottlingTests(TestCase): +    def setUp(self): +        """ +        Reset the cache so that no throttles will be active +        """ +        cache.clear() +        self.factory = APIRequestFactory() + +    def test_requests_are_throttled(self): +        """ +        Ensure request rate is limited +        """ +        request = self.factory.get('/') +        for dummy in range(4): +            response = MockView.as_view()(request) +        self.assertEqual(429, response.status_code) + +    def set_throttle_timer(self, view, value): +        """ +        Explicitly set the timer, overriding time.time() +        """ +        view.throttle_classes[0].timer = lambda self: value + +    def test_request_throttling_expires(self): +        """ +        Ensure request rate is limited for a limited duration only +        """ +        self.set_throttle_timer(MockView, 0) + +        request = self.factory.get('/') +        for dummy in range(4): +            response = MockView.as_view()(request) +        self.assertEqual(429, response.status_code) + +        # Advance the timer by one second +        self.set_throttle_timer(MockView, 1) + +        response = MockView.as_view()(request) +        self.assertEqual(200, response.status_code) + +    def ensure_is_throttled(self, view, expect): +        request = self.factory.get('/') +        request.user = User.objects.create(username='a') +        for dummy in range(3): +            view.as_view()(request) +        request.user = User.objects.create(username='b') +        response = view.as_view()(request) +        self.assertEqual(expect, response.status_code) + +    def test_request_throttling_is_per_user(self): +        """ +        Ensure request rate is only limited per user, not globally for +        PerUserThrottles +        """ +        self.ensure_is_throttled(MockView, 200) + +    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): +        """ +        Ensure the response returns an Retry-After field with status and next attributes +        set properly. +        """ +        request = self.factory.get('/') +        for timer, expect in expected_headers: +            self.set_throttle_timer(view, timer) +            response = view.as_view()(request) +            if expect is not None: +                self.assertEqual(response['Retry-After'], expect) +            else: +                self.assertFalse('Retry-After' in response) + +    def test_seconds_fields(self): +        """ +        Ensure for second based throttles. +        """ +        self.ensure_response_header_contains_proper_throttle_field( +            MockView, ( +                (0, None), +                (0, None), +                (0, None), +                (0, '1') +            ) +        ) + +    def test_minutes_fields(self): +        """ +        Ensure for minute based throttles. +        """ +        self.ensure_response_header_contains_proper_throttle_field( +            MockView_MinuteThrottling, ( +                (0, None), +                (0, None), +                (0, None), +                (0, '60') +            ) +        ) + +    def test_next_rate_remains_constant_if_followed(self): +        """ +        If a client follows the recommended next request rate, +        the throttling rate should stay constant. +        """ +        self.ensure_response_header_contains_proper_throttle_field( +            MockView_MinuteThrottling, ( +                (0, None), +                (20, None), +                (40, None), +                (60, None), +                (80, None) +            ) +        ) + +    def test_non_time_throttle(self): +        """ +        Ensure for second based throttles. +        """ +        request = self.factory.get('/') + +        self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) + +        response = MockView_NonTimeThrottling.as_view()(request) +        self.assertFalse('Retry-After' in response) + +        self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) + +        response = MockView_NonTimeThrottling.as_view()(request) +        self.assertFalse('Retry-After' in response) + + +class ScopedRateThrottleTests(TestCase): +    """ +    Tests for ScopedRateThrottle. +    """ + +    def setUp(self): +        class XYScopedRateThrottle(ScopedRateThrottle): +            TIMER_SECONDS = 0 +            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} + +            def timer(self): +                return self.TIMER_SECONDS + +        class XView(APIView): +            throttle_classes = (XYScopedRateThrottle,) +            throttle_scope = 'x' + +            def get(self, request): +                return Response('x') + +        class YView(APIView): +            throttle_classes = (XYScopedRateThrottle,) +            throttle_scope = 'y' + +            def get(self, request): +                return Response('y') + +        class UnscopedView(APIView): +            throttle_classes = (XYScopedRateThrottle,) + +            def get(self, request): +                return Response('y') + +        self.throttle_class = XYScopedRateThrottle +        self.factory = APIRequestFactory() +        self.x_view = XView.as_view() +        self.y_view = YView.as_view() +        self.unscoped_view = UnscopedView.as_view() + +    def increment_timer(self, seconds=1): +        self.throttle_class.TIMER_SECONDS += seconds + +    def test_scoped_rate_throttle(self): +        request = self.factory.get('/') + +        # Should be able to hit x view 3 times per minute. +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(429, response.status_code) + +        # Should be able to hit y view 1 time per minute. +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(429, response.status_code) + +        # Ensure throttles properly reset by advancing the rest of the minute +        self.increment_timer(55) + +        # Should still be able to hit x view 3 times per minute. +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(429, response.status_code) + +        # Should still be able to hit y view 1 time per minute. +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(429, response.status_code) + +    def test_unscoped_view_not_throttled(self): +        request = self.factory.get('/') + +        for idx in range(10): +            self.increment_timer() +            response = self.unscoped_view(request) +            self.assertEqual(200, response.status_code) + + +class XffTestingBase(TestCase): +    def setUp(self): + +        class Throttle(ScopedRateThrottle): +            THROTTLE_RATES = {'test_limit': '1/day'} +            TIMER_SECONDS = 0 + +            def timer(self): +                return self.TIMER_SECONDS + +        class View(APIView): +            throttle_classes = (Throttle,) +            throttle_scope = 'test_limit' + +            def get(self, request): +                return Response('test_limit') + +        cache.clear() +        self.throttle = Throttle() +        self.view = View.as_view() +        self.request = APIRequestFactory().get('/some_uri') +        self.request.META['REMOTE_ADDR'] = '3.3.3.3' +        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2' + +    def config_proxy(self, num_proxies): +        setattr(api_settings, 'NUM_PROXIES', num_proxies) + + +class IdWithXffBasicTests(XffTestingBase): +    def test_accepts_request_under_limit(self): +        self.config_proxy(0) +        self.assertEqual(200, self.view(self.request).status_code) + +    def test_denies_request_over_limit(self): +        self.config_proxy(0) +        self.view(self.request) +        self.assertEqual(429, self.view(self.request).status_code) + + +class XffSpoofingTests(XffTestingBase): +    def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self): +        self.config_proxy(1) +        self.view(self.request) +        self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2' +        self.assertEqual(429, self.view(self.request).status_code) + +    def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self): +        self.config_proxy(2) +        self.view(self.request) +        self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2' +        self.assertEqual(429, self.view(self.request).status_code) + + +class XffUniqueMachinesTest(XffTestingBase): +    def test_unique_clients_are_counted_independently_with_one_proxy(self): +        self.config_proxy(1) +        self.view(self.request) +        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7' +        self.assertEqual(200, self.view(self.request).status_code) + +    def test_unique_clients_are_counted_independently_with_two_proxies(self): +        self.config_proxy(2) +        self.view(self.request) +        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2' +        self.assertEqual(200, self.view(self.request).status_code) diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py new file mode 100644 index 00000000..e0060e69 --- /dev/null +++ b/tests/test_urlpatterns.py @@ -0,0 +1,76 @@ +from __future__ import unicode_literals +from collections import namedtuple +from django.conf.urls import patterns, url, include +from django.core import urlresolvers +from django.test import TestCase +from rest_framework.test import APIRequestFactory +from rest_framework.urlpatterns import format_suffix_patterns + + +# A container class for test paths for the test case +URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) + + +def dummy_view(request, *args, **kwargs): +    pass + + +class FormatSuffixTests(TestCase): +    """ +    Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. +    """ +    def _resolve_urlpatterns(self, urlpatterns, test_paths): +        factory = APIRequestFactory() +        try: +            urlpatterns = format_suffix_patterns(urlpatterns) +        except Exception: +            self.fail("Failed to apply `format_suffix_patterns` on  the supplied urlpatterns") +        resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) +        for test_path in test_paths: +            request = factory.get(test_path.path) +            try: +                callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) +            except Exception: +                self.fail("Failed to resolve URL: %s" % request.path_info) +            self.assertEqual(callback_args, test_path.args) +            self.assertEqual(callback_kwargs, test_path.kwargs) + +    def test_format_suffix(self): +        urlpatterns = patterns( +            '', +            url(r'^test$', dummy_view), +        ) +        test_paths = [ +            URLTestPath('/test', (), {}), +            URLTestPath('/test.api', (), {'format': 'api'}), +            URLTestPath('/test.asdf', (), {'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) + +    def test_default_args(self): +        urlpatterns = patterns( +            '', +            url(r'^test$', dummy_view, {'foo': 'bar'}), +        ) +        test_paths = [ +            URLTestPath('/test', (), {'foo': 'bar', }), +            URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), +            URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) + +    def test_included_urls(self): +        nested_patterns = patterns( +            '', +            url(r'^path$', dummy_view) +        ) +        urlpatterns = patterns( +            '', +            url(r'^test/', include(nested_patterns), {'foo': 'bar'}), +        ) +        test_paths = [ +            URLTestPath('/test/path', (), {'foo': 'bar', }), +            URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), +            URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..8c286ea4 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,166 @@ +from __future__ import unicode_literals +from django.core.exceptions import ImproperlyConfigured +from django.conf.urls import patterns, url +from django.test import TestCase +from django.utils import six +from rest_framework.utils.model_meta import _resolve_model +from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.views import APIView +from tests.models import BasicModel + +import rest_framework.utils.model_meta + + +class Root(APIView): +    pass + + +class ResourceRoot(APIView): +    pass + + +class ResourceInstance(APIView): +    pass + + +class NestedResourceRoot(APIView): +    pass + + +class NestedResourceInstance(APIView): +    pass + + +urlpatterns = patterns( +    '', +    url(r'^$', Root.as_view()), +    url(r'^resource/$', ResourceRoot.as_view()), +    url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), +    url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()), +    url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()), +) + + +class BreadcrumbTests(TestCase): +    """ +    Tests the breadcrumb functionality used by the HTML renderer. +    """ +    urls = 'tests.test_utils' + +    def test_root_breadcrumbs(self): +        url = '/' +        self.assertEqual( +            get_breadcrumbs(url), +            [('Root', '/')] +        ) + +    def test_resource_root_breadcrumbs(self): +        url = '/resource/' +        self.assertEqual( +            get_breadcrumbs(url), +            [ +                ('Root', '/'), +                ('Resource Root', '/resource/') +            ] +        ) + +    def test_resource_instance_breadcrumbs(self): +        url = '/resource/123' +        self.assertEqual( +            get_breadcrumbs(url), +            [ +                ('Root', '/'), +                ('Resource Root', '/resource/'), +                ('Resource Instance', '/resource/123') +            ] +        ) + +    def test_nested_resource_breadcrumbs(self): +        url = '/resource/123/' +        self.assertEqual( +            get_breadcrumbs(url), +            [ +                ('Root', '/'), +                ('Resource Root', '/resource/'), +                ('Resource Instance', '/resource/123'), +                ('Nested Resource Root', '/resource/123/') +            ] +        ) + +    def test_nested_resource_instance_breadcrumbs(self): +        url = '/resource/123/abc' +        self.assertEqual( +            get_breadcrumbs(url), +            [ +                ('Root', '/'), +                ('Resource Root', '/resource/'), +                ('Resource Instance', '/resource/123'), +                ('Nested Resource Root', '/resource/123/'), +                ('Nested Resource Instance', '/resource/123/abc') +            ] +        ) + +    def test_broken_url_breadcrumbs_handled_gracefully(self): +        url = '/foobar' +        self.assertEqual( +            get_breadcrumbs(url), +            [('Root', '/')] +        ) + + +class ResolveModelTests(TestCase): +    """ +    `_resolve_model` should return a Django model class given the +    provided argument is a Django model class itself, or a properly +    formatted string representation of one. +    """ +    def test_resolve_django_model(self): +        resolved_model = _resolve_model(BasicModel) +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_string_representation(self): +        resolved_model = _resolve_model('tests.BasicModel') +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_unicode_representation(self): +        resolved_model = _resolve_model(six.text_type('tests.BasicModel')) +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_non_django_model(self): +        with self.assertRaises(ValueError): +            _resolve_model(TestCase) + +    def test_resolve_improper_string_representation(self): +        with self.assertRaises(ValueError): +            _resolve_model('BasicModel') + + +class ResolveModelWithPatchedDjangoTests(TestCase): +    """ +    Test coverage for when Django's `get_model` returns `None`. + +    Under certain circumstances Django may return `None` with `get_model`: +    http://git.io/get-model-source + +    It usually happens with circular imports so it is important that DRF +    excepts early, otherwise fault happens downstream and is much more +    difficult to debug. + +    """ + +    def setUp(self): +        """Monkeypatch get_model.""" +        self.get_model = rest_framework.utils.model_meta.models.get_model + +        def get_model(app_label, model_name): +            return None + +        rest_framework.utils.model_meta.models.get_model = get_model + +    def tearDown(self): +        """Revert monkeypatching.""" +        rest_framework.utils.model_meta.models.get_model = self.get_model + +    def test_blows_up_if_model_does_not_resolve(self): +        with self.assertRaises(ImproperlyConfigured): +            _resolve_model('tests.BasicModel') diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 00000000..4234efd3 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,183 @@ +from __future__ import unicode_literals +from django.core.validators import RegexValidator, MaxValueValidator +from django.db import models +from django.test import TestCase +from rest_framework import generics, serializers, status +from rest_framework.test import APIRequestFactory +import re + +factory = APIRequestFactory() + + +# Regression for #666 + +class ValidationModel(models.Model): +    blank_validated_field = models.CharField(max_length=255) + + +class ValidationModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ValidationModel +        fields = ('blank_validated_field',) +        read_only_fields = ('blank_validated_field',) + + +class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView): +    queryset = ValidationModel.objects.all() +    serializer_class = ValidationModelSerializer + + +# Regression for #653 + +class ShouldValidateModel(models.Model): +    should_validate_field = models.CharField(max_length=255) + + +class ShouldValidateModelSerializer(serializers.ModelSerializer): +    renamed = serializers.CharField(source='should_validate_field', required=False) + +    def validate_renamed(self, value): +        if len(value) < 3: +            raise serializers.ValidationError('Minimum 3 characters.') +        return value + +    class Meta: +        model = ShouldValidateModel +        fields = ('renamed',) + + +class TestPreSaveValidationExclusionsSerializer(TestCase): +    def test_renamed_fields_are_model_validated(self): +        """ +        Ensure fields with 'source' applied do get still get model validation. +        """ +        # We've set `required=False` on the serializer, but the model +        # does not have `blank=True`, so this serializer should not validate. +        serializer = ShouldValidateModelSerializer(data={'renamed': ''}) +        self.assertEqual(serializer.is_valid(), False) +        self.assertIn('renamed', serializer.errors) +        self.assertNotIn('should_validate_field', serializer.errors) + + +class TestCustomValidationMethods(TestCase): +    def test_custom_validation_method_is_executed(self): +        serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'}) +        self.assertFalse(serializer.is_valid()) +        self.assertIn('renamed', serializer.errors) + +    def test_custom_validation_method_passing(self): +        serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'}) +        self.assertTrue(serializer.is_valid()) + + +class ValidationSerializer(serializers.Serializer): +    foo = serializers.CharField() + +    def validate_foo(self, attrs, source): +        raise serializers.ValidationError("foo invalid") + +    def validate(self, attrs): +        raise serializers.ValidationError("serializer invalid") + + +class TestAvoidValidation(TestCase): +    """ +    If serializer was initialized with invalid data (None or non dict-like), it +    should avoid validation layer (validate_<field> and validate methods) +    """ +    def test_serializer_errors_has_only_invalid_data_error(self): +        serializer = ValidationSerializer(data='invalid data') +        self.assertFalse(serializer.is_valid()) +        self.assertDictEqual(serializer.errors, { +            'non_field_errors': [ +                'Invalid data. Expected a dictionary, but got %s.' % type('').__name__ +            ] +        }) + + +# regression tests for issue: 1493 + +class ValidationMaxValueValidatorModel(models.Model): +    number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)]) + + +class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ValidationMaxValueValidatorModel + + +class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): +    queryset = ValidationMaxValueValidatorModel.objects.all() +    serializer_class = ValidationMaxValueValidatorModelSerializer + + +class TestMaxValueValidatorValidation(TestCase): + +    def test_max_value_validation_serializer_success(self): +        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99}) +        self.assertTrue(serializer.is_valid()) + +    def test_max_value_validation_serializer_fails(self): +        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101}) +        self.assertFalse(serializer.is_valid()) +        self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors) + +    def test_max_value_validation_success(self): +        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) +        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json') +        view = UpdateMaxValueValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_max_value_validation_fail(self): +        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) +        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') +        view = UpdateMaxValueValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.content, b'{"number_value":["Ensure this value is less than or equal to 100."]}') +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + +class TestChoiceFieldChoicesValidate(TestCase): +    CHOICES = [ +        (0, 'Small'), +        (1, 'Medium'), +        (2, 'Large'), +    ] + +    CHOICES_NESTED = [ +        ('Category', ( +            (1, 'First'), +            (2, 'Second'), +            (3, 'Third'), +        )), +        (4, 'Fourth'), +    ] + +    def test_choices(self): +        """ +        Make sure a value for choices works as expected. +        """ +        f = serializers.ChoiceField(choices=self.CHOICES) +        value = self.CHOICES[0][0] +        try: +            f.to_internal_value(value) +        except serializers.ValidationError: +            self.fail("Value %s does not validate" % str(value)) + + +class RegexSerializer(serializers.Serializer): +    pin = serializers.CharField( +        validators=[RegexValidator(regex=re.compile('^[0-9]{4,6}$'), +                                   message='A PIN is 4-6 digits')]) + +expected_repr = """ +RegexSerializer(): +    pin = CharField(validators=[<django.core.validators.RegexValidator object>]) +""".strip() + + +class TestRegexSerializer(TestCase): +    def test_regex_repr(self): +        serializer_repr = repr(RegexSerializer()) +        assert serializer_repr == expected_repr diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 00000000..127ec6f8 --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,347 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers +import datetime + + +def dedent(blocktext): +    return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]]) + + +# Tests for `UniqueValidator` +# --------------------------- + +class UniquenessModel(models.Model): +    username = models.CharField(unique=True, max_length=100) + + +class UniquenessSerializer(serializers.ModelSerializer): +    class Meta: +        model = UniquenessModel + + +class AnotherUniquenessModel(models.Model): +    code = models.IntegerField(unique=True) + + +class AnotherUniquenessSerializer(serializers.ModelSerializer): +    class Meta: +        model = AnotherUniquenessModel + + +class TestUniquenessValidation(TestCase): +    def setUp(self): +        self.instance = UniquenessModel.objects.create(username='existing') + +    def test_repr(self): +        serializer = UniquenessSerializer() +        expected = dedent(""" +            UniquenessSerializer(): +                id = IntegerField(label='ID', read_only=True) +                username = CharField(max_length=100, validators=[<UniqueValidator(queryset=UniquenessModel.objects.all())>]) +        """) +        assert repr(serializer) == expected + +    def test_is_not_unique(self): +        data = {'username': 'existing'} +        serializer = UniquenessSerializer(data=data) +        assert not serializer.is_valid() +        assert serializer.errors == {'username': ['This field must be unique.']} + +    def test_is_unique(self): +        data = {'username': 'other'} +        serializer = UniquenessSerializer(data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == {'username': 'other'} + +    def test_updated_instance_excluded(self): +        data = {'username': 'existing'} +        serializer = UniquenessSerializer(self.instance, data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == {'username': 'existing'} + +    def test_doesnt_pollute_model(self): +        instance = AnotherUniquenessModel.objects.create(code='100') +        serializer = AnotherUniquenessSerializer(instance) +        self.assertEqual( +            AnotherUniquenessModel._meta.get_field('code').validators, []) + +        # Accessing data shouldn't effect validators on the model +        serializer.data +        self.assertEqual( +            AnotherUniquenessModel._meta.get_field('code').validators, []) + + +# Tests for `UniqueTogetherValidator` +# ----------------------------------- + +class UniquenessTogetherModel(models.Model): +    race_name = models.CharField(max_length=100) +    position = models.IntegerField() + +    class Meta: +        unique_together = ('race_name', 'position') + + +class NullUniquenessTogetherModel(models.Model): +    """ +    Used to ensure that null values are not included when checking +    unique_together constraints. + +    Ignoring items which have a null in any of the validated fields is the same +    behavior that database backends will use when they have the +    unique_together constraint added. + +    Example case: a null position could indicate a non-finisher in the race, +    there could be many non-finishers in a race, but all non-NULL +    values *should* be unique against the given `race_name`. +    """ +    date_of_birth = models.DateField(null=True)  # Not part of the uniqueness constraint +    race_name = models.CharField(max_length=100) +    position = models.IntegerField(null=True) + +    class Meta: +        unique_together = ('race_name', 'position') + + +class UniquenessTogetherSerializer(serializers.ModelSerializer): +    class Meta: +        model = UniquenessTogetherModel + + +class NullUniquenessTogetherSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullUniquenessTogetherModel + + +class TestUniquenessTogetherValidation(TestCase): +    def setUp(self): +        self.instance = UniquenessTogetherModel.objects.create( +            race_name='example', +            position=1 +        ) +        UniquenessTogetherModel.objects.create( +            race_name='example', +            position=2 +        ) +        UniquenessTogetherModel.objects.create( +            race_name='other', +            position=1 +        ) + +    def test_repr(self): +        serializer = UniquenessTogetherSerializer() +        expected = dedent(""" +            UniquenessTogetherSerializer(): +                id = IntegerField(label='ID', read_only=True) +                race_name = CharField(max_length=100, required=True) +                position = IntegerField(required=True) +                class Meta: +                    validators = [<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>] +        """) +        assert repr(serializer) == expected + +    def test_is_not_unique_together(self): +        """ +        Failing unique together validation should result in non field errors. +        """ +        data = {'race_name': 'example', 'position': 2} +        serializer = UniquenessTogetherSerializer(data=data) +        assert not serializer.is_valid() +        assert serializer.errors == { +            'non_field_errors': [ +                'The fields race_name, position must make a unique set.' +            ] +        } + +    def test_is_unique_together(self): +        """ +        In a unique together validation, one field may be non-unique +        so long as the set as a whole is unique. +        """ +        data = {'race_name': 'other', 'position': 2} +        serializer = UniquenessTogetherSerializer(data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'race_name': 'other', +            'position': 2 +        } + +    def test_updated_instance_excluded_from_unique_together(self): +        """ +        When performing an update, the existing instance does not count +        as a match against uniqueness. +        """ +        data = {'race_name': 'example', 'position': 1} +        serializer = UniquenessTogetherSerializer(self.instance, data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'race_name': 'example', +            'position': 1 +        } + +    def test_unique_together_is_required(self): +        """ +        In a unique together validation, all fields are required. +        """ +        data = {'position': 2} +        serializer = UniquenessTogetherSerializer(data=data, partial=True) +        assert not serializer.is_valid() +        assert serializer.errors == { +            'race_name': ['This field is required.'] +        } + +    def test_ignore_excluded_fields(self): +        """ +        When model fields are not included in a serializer, then uniqueness +        validators should not be added for that field. +        """ +        class ExcludedFieldSerializer(serializers.ModelSerializer): +            class Meta: +                model = UniquenessTogetherModel +                fields = ('id', 'race_name',) +        serializer = ExcludedFieldSerializer() +        expected = dedent(""" +            ExcludedFieldSerializer(): +                id = IntegerField(label='ID', read_only=True) +                race_name = CharField(max_length=100) +        """) +        assert repr(serializer) == expected + +    def test_ignore_validation_for_null_fields(self): +        # None values that are on fields which are part of the uniqueness +        # constraint cause the instance to ignore uniqueness validation. +        NullUniquenessTogetherModel.objects.create( +            date_of_birth=datetime.date(2000, 1, 1), +            race_name='Paris Marathon', +            position=None +        ) +        data = { +            'date': datetime.date(2000, 1, 1), +            'race_name': 'Paris Marathon', +            'position': None +        } +        serializer = NullUniquenessTogetherSerializer(data=data) +        assert serializer.is_valid() + +    def test_do_not_ignore_validation_for_null_fields(self): +        # None values that are not on fields part of the uniqueness constraint +        # do not cause the instance to skip validation. +        NullUniquenessTogetherModel.objects.create( +            date_of_birth=datetime.date(2000, 1, 1), +            race_name='Paris Marathon', +            position=1 +        ) +        data = {'date': None, 'race_name': 'Paris Marathon', 'position': 1} +        serializer = NullUniquenessTogetherSerializer(data=data) +        assert not serializer.is_valid() + + +# Tests for `UniqueForDateValidator` +# ---------------------------------- + +class UniqueForDateModel(models.Model): +    slug = models.CharField(max_length=100, unique_for_date='published') +    published = models.DateField() + + +class UniqueForDateSerializer(serializers.ModelSerializer): +    class Meta: +        model = UniqueForDateModel + + +class TestUniquenessForDateValidation(TestCase): +    def setUp(self): +        self.instance = UniqueForDateModel.objects.create( +            slug='existing', +            published='2000-01-01' +        ) + +    def test_repr(self): +        serializer = UniqueForDateSerializer() +        expected = dedent(""" +            UniqueForDateSerializer(): +                id = IntegerField(label='ID', read_only=True) +                slug = CharField(max_length=100) +                published = DateField(required=True) +                class Meta: +                    validators = [<UniqueForDateValidator(queryset=UniqueForDateModel.objects.all(), field='slug', date_field='published')>] +        """) +        assert repr(serializer) == expected + +    def test_is_not_unique_for_date(self): +        """ +        Failing unique for date validation should result in field error. +        """ +        data = {'slug': 'existing', 'published': '2000-01-01'} +        serializer = UniqueForDateSerializer(data=data) +        assert not serializer.is_valid() +        assert serializer.errors == { +            'slug': ['This field must be unique for the "published" date.'] +        } + +    def test_is_unique_for_date(self): +        """ +        Passing unique for date validation. +        """ +        data = {'slug': 'existing', 'published': '2000-01-02'} +        serializer = UniqueForDateSerializer(data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'slug': 'existing', +            'published': datetime.date(2000, 1, 2) +        } + +    def test_updated_instance_excluded_from_unique_for_date(self): +        """ +        When performing an update, the existing instance does not count +        as a match against unique_for_date. +        """ +        data = {'slug': 'existing', 'published': '2000-01-01'} +        serializer = UniqueForDateSerializer(instance=self.instance, data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'slug': 'existing', +            'published': datetime.date(2000, 1, 1) +        } + + +class HiddenFieldUniqueForDateModel(models.Model): +    slug = models.CharField(max_length=100, unique_for_date='published') +    published = models.DateTimeField(auto_now_add=True) + + +class TestHiddenFieldUniquenessForDateValidation(TestCase): +    def test_repr_date_field_not_included(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = HiddenFieldUniqueForDateModel +                fields = ('id', 'slug') + +        serializer = TestSerializer() +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                slug = CharField(max_length=100) +                published = HiddenField(default=CreateOnlyDefault(<function now>)) +                class Meta: +                    validators = [<UniqueForDateValidator(queryset=HiddenFieldUniqueForDateModel.objects.all(), field='slug', date_field='published')>] +        """) +        assert repr(serializer) == expected + +    def test_repr_date_field_included(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = HiddenFieldUniqueForDateModel +                fields = ('id', 'slug', 'published') + +        serializer = TestSerializer() +        expected = dedent(""" +            TestSerializer(): +                id = IntegerField(label='ID', read_only=True) +                slug = CharField(max_length=100) +                published = DateTimeField(default=CreateOnlyDefault(<function now>), read_only=True) +                class Meta: +                    validators = [<UniqueForDateValidator(queryset=HiddenFieldUniqueForDateModel.objects.all(), field='slug', date_field='published')>] +        """) +        assert repr(serializer) == expected diff --git a/tests/test_versioning.py b/tests/test_versioning.py new file mode 100644 index 00000000..90ad8afd --- /dev/null +++ b/tests/test_versioning.py @@ -0,0 +1,264 @@ +from .utils import UsingURLPatterns +from django.conf.urls import include, url +from rest_framework import serializers +from rest_framework import status, versioning +from rest_framework.decorators import APIView +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory, APITestCase +from rest_framework.versioning import NamespaceVersioning +import pytest + + +class RequestVersionView(APIView): +    def get(self, request, *args, **kwargs): +        return Response({'version': request.version}) + + +class ReverseView(APIView): +    def get(self, request, *args, **kwargs): +        return Response({'url': reverse('another', request=request)}) + + +class RequestInvalidVersionView(APIView): +    def determine_version(self, request, *args, **kwargs): +        scheme = self.versioning_class() +        scheme.allowed_versions = ('v1', 'v2') +        return (scheme.determine_version(request, *args, **kwargs), scheme) + +    def get(self, request, *args, **kwargs): +        return Response({'version': request.version}) + + +factory = APIRequestFactory() + + +def dummy_view(request): +    pass + + +def dummy_pk_view(request, pk): +    pass + + +class TestRequestVersion: +    def test_unversioned(self): +        view = RequestVersionView.as_view() + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_query_param_versioning(self): +        scheme = versioning.QueryParameterVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/?version=1.2.3') +        response = view(request) +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_host_name_versioning(self): +        scheme = versioning.HostNameVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_HOST='v1.example.org') +        response = view(request) +        assert response.data == {'version': 'v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_accept_header_versioning(self): +        scheme = versioning.AcceptHeaderVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=1.2.3') +        response = view(request) +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/', HTTP_ACCEPT='application/json') +        response = view(request) +        assert response.data == {'version': None} + +    def test_url_path_versioning(self): +        scheme = versioning.URLPathVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/1.2.3/endpoint/') +        response = view(request, version='1.2.3') +        assert response.data == {'version': '1.2.3'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + +    def test_namespace_versioning(self): +        class FakeResolverMatch: +            namespace = 'v1' + +        scheme = versioning.NamespaceVersioning +        view = RequestVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        request.resolver_match = FakeResolverMatch +        response = view(request, version='v1') +        assert response.data == {'version': 'v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'version': None} + + +class TestURLReversing(UsingURLPatterns, APITestCase): +    included = [ +        url(r'^namespaced/$', dummy_view, name='another'), +        url(r'^example/(?P<pk>\d+)/$', dummy_pk_view, name='example-detail') +    ] + +    urlpatterns = [ +        url(r'^v1/', include(included, namespace='v1')), +        url(r'^another/$', dummy_view, name='another'), +        url(r'^(?P<version>[^/]+)/another/$', dummy_view, name='another'), +    ] + +    def test_reverse_unversioned(self): +        view = ReverseView.as_view() + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_query_param_versioning(self): +        scheme = versioning.QueryParameterVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/?version=v1') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/?version=v1'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_host_name_versioning(self): +        scheme = versioning.HostNameVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_HOST='v1.example.org') +        response = view(request) +        assert response.data == {'url': 'http://v1.example.org/another/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_url_path_versioning(self): +        scheme = versioning.URLPathVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        response = view(request, version='v1') +        assert response.data == {'url': 'http://testserver/v1/another/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + +    def test_reverse_namespace_versioning(self): +        class FakeResolverMatch: +            namespace = 'v1' + +        scheme = versioning.NamespaceVersioning +        view = ReverseView.as_view(versioning_class=scheme) + +        request = factory.get('/v1/endpoint/') +        request.resolver_match = FakeResolverMatch +        response = view(request, version='v1') +        assert response.data == {'url': 'http://testserver/v1/namespaced/'} + +        request = factory.get('/endpoint/') +        response = view(request) +        assert response.data == {'url': 'http://testserver/another/'} + + +class TestInvalidVersion: +    def test_invalid_query_param_versioning(self): +        scheme = versioning.QueryParameterVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/?version=v3') +        response = view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND + +    def test_invalid_host_name_versioning(self): +        scheme = versioning.HostNameVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_HOST='v3.example.org') +        response = view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND + +    def test_invalid_accept_header_versioning(self): +        scheme = versioning.AcceptHeaderVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3') +        response = view(request) +        assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE + +    def test_invalid_url_path_versioning(self): +        scheme = versioning.URLPathVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/v3/endpoint/') +        response = view(request, version='v3') +        assert response.status_code == status.HTTP_404_NOT_FOUND + +    def test_invalid_namespace_versioning(self): +        class FakeResolverMatch: +            namespace = 'v3' + +        scheme = versioning.NamespaceVersioning +        view = RequestInvalidVersionView.as_view(versioning_class=scheme) + +        request = factory.get('/v3/endpoint/') +        request.resolver_match = FakeResolverMatch +        response = view(request, version='v3') +        assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase): +    included = [ +        url(r'^namespaced/(?P<pk>\d+)/$', dummy_view, name='namespaced'), +    ] + +    urlpatterns = [ +        url(r'^v1/', include(included, namespace='v1')), +        url(r'^v2/', include(included, namespace='v2')) +    ] + +    def setUp(self): +        super(TestHyperlinkedRelatedField, self).setUp() + +        class MockQueryset(object): +            def get(self, pk): +                return 'object %s' % pk + +        self.field = serializers.HyperlinkedRelatedField( +            view_name='namespaced', +            queryset=MockQueryset() +        ) +        request = factory.get('/') +        request.versioning_scheme = NamespaceVersioning() +        request.version = 'v1' +        self.field._context = {'request': request} + +    def test_bug_2489(self): +        assert self.field.to_internal_value('/v1/namespaced/3/') == 'object 3' +        with pytest.raises(serializers.ValidationError): +            self.field.to_internal_value('/v2/namespaced/3/') diff --git a/tests/test_views.py b/tests/test_views.py new file mode 100644 index 00000000..77b113ee --- /dev/null +++ b/tests/test_views.py @@ -0,0 +1,148 @@ +from __future__ import unicode_literals + +import sys +import copy +from django.test import TestCase +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + +if sys.version_info[:2] >= (3, 4): +    JSON_ERROR = 'JSON parse error - Expecting value:' +else: +    JSON_ERROR = 'JSON parse error - No JSON object could be decoded' + + +class BasicView(APIView): +    def get(self, request, *args, **kwargs): +        return Response({'method': 'GET'}) + +    def post(self, request, *args, **kwargs): +        return Response({'method': 'POST', 'data': request.DATA}) + + +@api_view(['GET', 'POST', 'PUT', 'PATCH']) +def basic_view(request): +    if request.method == 'GET': +        return {'method': 'GET'} +    elif request.method == 'POST': +        return {'method': 'POST', 'data': request.DATA} +    elif request.method == 'PUT': +        return {'method': 'PUT', 'data': request.DATA} +    elif request.method == 'PATCH': +        return {'method': 'PATCH', 'data': request.DATA} + + +class ErrorView(APIView): +    def get(self, request, *args, **kwargs): +        raise Exception + + +@api_view(['GET']) +def error_view(request): +    raise Exception + + +def sanitise_json_error(error_dict): +    """ +    Exact contents of JSON error messages depend on the installed version +    of json. +    """ +    ret = copy.copy(error_dict) +    chop = len(JSON_ERROR) +    ret['detail'] = ret['detail'][:chop] +    return ret + + +class ClassBasedViewIntegrationTests(TestCase): +    def setUp(self): +        self.view = BasicView.as_view() + +    def test_400_parse_error(self): +        request = factory.post('/', 'f00bar', content_type='application/json') +        response = self.view(request) +        expected = { +            'detail': JSON_ERROR +        } +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(sanitise_json_error(response.data), expected) + +    def test_400_parse_error_tunneled_content(self): +        content = 'f00bar' +        content_type = 'application/json' +        form_data = { +            api_settings.FORM_CONTENT_OVERRIDE: content, +            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type +        } +        request = factory.post('/', form_data) +        response = self.view(request) +        expected = { +            'detail': JSON_ERROR +        } +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(sanitise_json_error(response.data), expected) + + +class FunctionBasedViewIntegrationTests(TestCase): +    def setUp(self): +        self.view = basic_view + +    def test_400_parse_error(self): +        request = factory.post('/', 'f00bar', content_type='application/json') +        response = self.view(request) +        expected = { +            'detail': JSON_ERROR +        } +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(sanitise_json_error(response.data), expected) + +    def test_400_parse_error_tunneled_content(self): +        content = 'f00bar' +        content_type = 'application/json' +        form_data = { +            api_settings.FORM_CONTENT_OVERRIDE: content, +            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type +        } +        request = factory.post('/', form_data) +        response = self.view(request) +        expected = { +            'detail': JSON_ERROR +        } +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(sanitise_json_error(response.data), expected) + + +class TestCustomExceptionHandler(TestCase): +    def setUp(self): +        self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + +        def exception_handler(exc): +            return Response('Error!', status=status.HTTP_400_BAD_REQUEST) + +        api_settings.EXCEPTION_HANDLER = exception_handler + +    def tearDown(self): +        api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + +    def test_class_based_view_exception_handler(self): +        view = ErrorView.as_view() + +        request = factory.get('/', content_type='application/json') +        response = view(request) +        expected = 'Error!' +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(response.data, expected) + +    def test_function_based_view_exception_handler(self): +        view = error_view + +        request = factory.get('/', content_type='application/json') +        response = view(request) +        expected = 'Error!' +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(response.data, expected) diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py new file mode 100644 index 00000000..4d18a955 --- /dev/null +++ b/tests/test_viewsets.py @@ -0,0 +1,35 @@ +from django.test import TestCase +from rest_framework import status +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.viewsets import GenericViewSet + + +factory = APIRequestFactory() + + +class BasicViewSet(GenericViewSet): +    def list(self, request, *args, **kwargs): +        return Response({'ACTION': 'LIST'}) + + +class InitializeViewSetsTestCase(TestCase): +    def test_initialize_view_set_with_actions(self): +        request = factory.get('/', '', content_type='application/json') +        my_view = BasicViewSet.as_view(actions={ +            'get': 'list', +        }) + +        response = my_view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, {'ACTION': 'LIST'}) + +    def test_initialize_view_set_with_empty_actions(self): +        try: +            BasicViewSet.as_view() +        except TypeError as e: +            self.assertEqual(str(e), "The `actions` argument must be provided " +                                     "when calling `.as_view()` on a ViewSet. " +                                     "For example `.as_view({'get': 'list'})`") +        else: +            self.fail("actions must not be empty.") diff --git a/tests/test_write_only_fields.py b/tests/test_write_only_fields.py new file mode 100644 index 00000000..dd3bbd6e --- /dev/null +++ b/tests/test_write_only_fields.py @@ -0,0 +1,31 @@ +from django.test import TestCase +from rest_framework import serializers + + +class WriteOnlyFieldTests(TestCase): +    def setUp(self): +        class ExampleSerializer(serializers.Serializer): +            email = serializers.EmailField() +            password = serializers.CharField(write_only=True) + +            def create(self, attrs): +                return attrs + +        self.Serializer = ExampleSerializer + +    def write_only_fields_are_present_on_input(self): +        data = { +            'email': 'foo@example.com', +            'password': '123' +        } +        serializer = self.Serializer(data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.validated_data, data) + +    def write_only_fields_are_not_present_on_output(self): +        instance = { +            'email': 'foo@example.com', +            'password': '123' +        } +        serializer = self.Serializer(instance) +        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) diff --git a/tests/urls.py b/tests/urls.py new file mode 100644 index 00000000..41f527df --- /dev/null +++ b/tests/urls.py @@ -0,0 +1,6 @@ +""" +Blank URLConf just to keep the test suite happy +""" +from django.conf.urls import patterns + +urlpatterns = patterns('') diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..b9034996 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,77 @@ +from django.core.exceptions import ObjectDoesNotExist +from django.core.urlresolvers import NoReverseMatch + + +class UsingURLPatterns(object): +    """ +    Isolates URL patterns used during testing on the test class itself. +    For example: + +    class MyTestCase(UsingURLPatterns, TestCase): +        urlpatterns = [ +            ... +        ] + +        def test_something(self): +            ... +    """ +    urls = __name__ + +    def setUp(self): +        global urlpatterns +        urlpatterns = self.urlpatterns + +    def tearDown(self): +        global urlpatterns +        urlpatterns = [] + + +class MockObject(object): +    def __init__(self, **kwargs): +        self._kwargs = kwargs +        for key, val in kwargs.items(): +            setattr(self, key, val) + +    def __str__(self): +        kwargs_str = ', '.join([ +            '%s=%s' % (key, value) +            for key, value in sorted(self._kwargs.items()) +        ]) +        return '<MockObject %s>' % kwargs_str + + +class MockQueryset(object): +    def __init__(self, iterable): +        self.items = iterable + +    def get(self, **lookup): +        for item in self.items: +            if all([ +                getattr(item, key, None) == value +                for key, value in lookup.items() +            ]): +                return item +        raise ObjectDoesNotExist() + + +class BadType(object): +    """ +    When used as a lookup with a `MockQueryset`, these objects +    will raise a `TypeError`, as occurs in Django when making +    queryset lookups with an incorrect type for the lookup value. +    """ +    def __eq__(self): +        raise TypeError() + + +def mock_reverse(view_name, args=None, kwargs=None, request=None, format=None): +    args = args or [] +    kwargs = kwargs or {} +    value = (args + list(kwargs.values()) + ['-'])[0] +    prefix = 'http://example.org' if request else '' +    suffix = ('.' + format) if (format is not None) else '' +    return '%s/%s/%s%s/' % (prefix, view_name, value, suffix) + + +def fail_reverse(view_name, args=None, kwargs=None, request=None, format=None): +    raise NoReverseMatch() | 
