diff options
Diffstat (limited to 'tests')
61 files changed, 12522 insertions, 0 deletions
| diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/tests/__init__.py diff --git a/tests/accounts/__init__.py b/tests/accounts/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/tests/accounts/__init__.py diff --git a/tests/accounts/models.py b/tests/accounts/models.py new file mode 100644 index 00000000..3bf4a0c3 --- /dev/null +++ b/tests/accounts/models.py @@ -0,0 +1,8 @@ +from django.db import models + +from tests.users.models import User + + +class Account(models.Model): +    owner = models.ForeignKey(User, related_name='accounts_owned') +    admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered') diff --git a/tests/accounts/serializers.py b/tests/accounts/serializers.py new file mode 100644 index 00000000..57a91b92 --- /dev/null +++ b/tests/accounts/serializers.py @@ -0,0 +1,11 @@ +from rest_framework import serializers + +from tests.accounts.models import Account +from tests.users.serializers import UserSerializer + + +class AccountSerializer(serializers.ModelSerializer): +    admins = UserSerializer(many=True) + +    class Meta: +        model = Account diff --git a/tests/description.py b/tests/description.py new file mode 100644 index 00000000..b46d7f54 --- /dev/null +++ b/tests/description.py @@ -0,0 +1,26 @@ +# -- coding: utf-8 -- + +# Apparently there is a python 2.6 issue where docstrings of imported view classes +# do not retain their encoding information even if a module has a proper +# encoding declaration at the top of its source file. Therefore for tests +# to catch unicode related errors, a mock view has to be declared in a separate +# module. + +from rest_framework.views import APIView + + +# test strings snatched from http://www.columbia.edu/~fdc/utf8/, +# http://winrus.com/utf8-jap.htm and memory +UTF8_TEST_DOCSTRING = ( +    'zażółć gęślą jaźń' +    'Sîne klâwen durh die wolken sint geslagen' +    'Τη γλώσσα μου έδωσαν ελληνική' +    'யாமறிந்த மொழிகளிலே தமிழ்மொழி' +    'На берегу пустынных волн' +    'てすと' +    'アイウエオカキクケコサシスセソタチツテ' +) + + +class ViewWithNonASCIICharactersInDocstring(APIView): +    __doc__ = UTF8_TEST_DOCSTRING diff --git a/tests/extras/__init__.py b/tests/extras/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/tests/extras/__init__.py diff --git a/tests/extras/bad_import.py b/tests/extras/bad_import.py new file mode 100644 index 00000000..68263d94 --- /dev/null +++ b/tests/extras/bad_import.py @@ -0,0 +1 @@ +raise ValueError diff --git a/tests/models.py b/tests/models.py new file mode 100644 index 00000000..0256697a --- /dev/null +++ b/tests/models.py @@ -0,0 +1,178 @@ +from __future__ import unicode_literals +from django.db import models +from django.utils.translation import ugettext_lazy as _ +from rest_framework import serializers + + +def foobar(): +    return 'foobar' + + +class CustomField(models.CharField): + +    def __init__(self, *args, **kwargs): +        kwargs['max_length'] = 12 +        super(CustomField, self).__init__(*args, **kwargs) + + +class RESTFrameworkModel(models.Model): +    """ +    Base for test models that sets app_label, so they play nicely. +    """ +    class Meta: +        app_label = 'tests' +        abstract = True + + +class HasPositiveIntegerAsChoice(RESTFrameworkModel): +    some_choices = ((1, 'A'), (2, 'B'), (3, 'C')) +    some_integer = models.PositiveIntegerField(choices=some_choices) + + +class Anchor(RESTFrameworkModel): +    text = models.CharField(max_length=100, default='anchor') + + +class BasicModel(RESTFrameworkModel): +    text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description.")) + + +class SlugBasedModel(RESTFrameworkModel): +    text = models.CharField(max_length=100) +    slug = models.SlugField(max_length=32) + + +class DefaultValueModel(RESTFrameworkModel): +    text = models.CharField(default='foobar', max_length=100) +    extra = models.CharField(blank=True, null=True, max_length=100) + + +class CallableDefaultValueModel(RESTFrameworkModel): +    text = models.CharField(default=foobar, max_length=100) + + +class ManyToManyModel(RESTFrameworkModel): +    rel = models.ManyToManyField(Anchor, help_text='Some help text.') + + +class ReadOnlyManyToManyModel(RESTFrameworkModel): +    text = models.CharField(max_length=100, default='anchor') +    rel = models.ManyToManyField(Anchor) + + +# Model for regression test for #285 + +class Comment(RESTFrameworkModel): +    email = models.EmailField() +    content = models.CharField(max_length=200) +    created = models.DateTimeField(auto_now_add=True) + + +class ActionItem(RESTFrameworkModel): +    title = models.CharField(max_length=200) +    started = models.NullBooleanField(default=False) +    done = models.BooleanField(default=False) +    info = CustomField(default='---', max_length=12) + + +# Models for reverse relations +class Person(RESTFrameworkModel): +    name = models.CharField(max_length=10) +    age = models.IntegerField(null=True, blank=True) + +    @property +    def info(self): +        return { +            'name': self.name, +            'age': self.age, +        } + + +class BlogPost(RESTFrameworkModel): +    title = models.CharField(max_length=100) +    writer = models.ForeignKey(Person, null=True, blank=True) + +    def get_first_comment(self): +        return self.blogpostcomment_set.all()[0] + + +class BlogPostComment(RESTFrameworkModel): +    text = models.TextField() +    blog_post = models.ForeignKey(BlogPost) + + +class Album(RESTFrameworkModel): +    title = models.CharField(max_length=100, unique=True) +    ref = models.CharField(max_length=10, unique=True, null=True, blank=True) + +class Photo(RESTFrameworkModel): +    description = models.TextField() +    album = models.ForeignKey(Album) + + +# Model for issue #324 +class BlankFieldModel(RESTFrameworkModel): +    title = models.CharField(max_length=100, blank=True, null=False) + + +# Model for issue #380 +class OptionalRelationModel(RESTFrameworkModel): +    other = models.ForeignKey('OptionalRelationModel', blank=True, null=True) + + +# Model for RegexField +class Book(RESTFrameworkModel): +    isbn = models.CharField(max_length=13) + + +# Models for relations tests +# ManyToMany +class ManyToManyTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class ManyToManySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') + + +# ForeignKey +class ForeignKeyTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class ForeignKeySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, related_name='sources', +                               help_text='Target', verbose_name='Target') + + +# Nullable ForeignKey +class NullableForeignKeySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, +                               related_name='nullable_sources') + + +# OneToOne +class OneToOneTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class NullableOneToOneSource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.OneToOneField(OneToOneTarget, null=True, blank=True, +                                  related_name='nullable_source') + + +# Serializer used to test BasicModel +class BasicModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = BasicModel + + +# Models to test filters +class FilterableItem(models.Model): +    text = models.CharField(max_length=100) +    decimal = models.DecimalField(max_digits=4, decimal_places=2) +    date = models.DateField() diff --git a/tests/records/__init__.py b/tests/records/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/tests/records/__init__.py diff --git a/tests/records/models.py b/tests/records/models.py new file mode 100644 index 00000000..76954807 --- /dev/null +++ b/tests/records/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class Record(models.Model): +    account = models.ForeignKey('accounts.Account', blank=True, null=True) +    owner = models.ForeignKey('users.User', blank=True, null=True) diff --git a/tests/serializers.py b/tests/serializers.py new file mode 100644 index 00000000..f2f85b6e --- /dev/null +++ b/tests/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from tests.models import NullableForeignKeySource + + +class NullableFKSourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableForeignKeySource diff --git a/tests/settings.py b/tests/settings.py new file mode 100644 index 00000000..75f7c54b --- /dev/null +++ b/tests/settings.py @@ -0,0 +1,169 @@ +# Django settings for testproject project. + +DEBUG = True +TEMPLATE_DEBUG = DEBUG +DEBUG_PROPAGATE_EXCEPTIONS = True + +ALLOWED_HOSTS = ['*'] + +ADMINS = ( +    # ('Your Name', 'your_email@domain.com'), +) + +MANAGERS = ADMINS + +DATABASES = { +    'default': { +        'ENGINE': 'django.db.backends.sqlite3',  # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. +        'NAME': 'sqlite.db',                     # Or path to database file if using sqlite3. +        'USER': '',                      # Not used with sqlite3. +        'PASSWORD': '',                  # Not used with sqlite3. +        'HOST': '',                      # Set to empty string for localhost. Not used with sqlite3. +        'PORT': '',                      # Set to empty string for default. Not used with sqlite3. +    } +} + +CACHES = { +    'default': { +        'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', +    } +} + +# Local time zone for this installation. Choices can be found here: +# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name +# although not all choices may be available on all operating systems. +# On Unix systems, a value of None will cause Django to use the same +# timezone as the operating system. +# If running in a Windows environment this must be set to the same as your +# system time zone. +TIME_ZONE = 'Europe/London' + +# Language code for this installation. All choices can be found here: +# http://www.i18nguy.com/unicode/language-identifiers.html +LANGUAGE_CODE = 'en-uk' + +SITE_ID = 1 + +# If you set this to False, Django will make some optimizations so as not +# to load the internationalization machinery. +USE_I18N = True + +# If you set this to False, Django will not format dates, numbers and +# calendars according to the current locale +USE_L10N = True + +# Absolute filesystem path to the directory that will hold user-uploaded files. +# Example: "/home/media/media.lawrence.com/" +MEDIA_ROOT = '' + +# URL that handles the media served from MEDIA_ROOT. Make sure to use a +# trailing slash if there is a path component (optional in other cases). +# Examples: "http://media.lawrence.com", "http://example.com/media/" +MEDIA_URL = '' + +# Make this unique, and don't share it with anybody. +SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy' + +# List of callables that know how to import templates from various sources. +TEMPLATE_LOADERS = ( +    'django.template.loaders.filesystem.Loader', +    'django.template.loaders.app_directories.Loader', +#     'django.template.loaders.eggs.Loader', +) + +MIDDLEWARE_CLASSES = ( +    'django.middleware.common.CommonMiddleware', +    'django.contrib.sessions.middleware.SessionMiddleware', +    'django.middleware.csrf.CsrfViewMiddleware', +    'django.contrib.auth.middleware.AuthenticationMiddleware', +    'django.contrib.messages.middleware.MessageMiddleware', +) + +ROOT_URLCONF = 'tests.urls' + +TEMPLATE_DIRS = ( +    # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates". +    # Always use forward slashes, even on Windows. +    # Don't forget to use absolute paths, not relative paths. +) + +INSTALLED_APPS = ( +    'django.contrib.auth', +    'django.contrib.contenttypes', +    'django.contrib.sessions', +    'django.contrib.sites', +    'django.contrib.messages', +    # Uncomment the next line to enable the admin: +    # 'django.contrib.admin', +    # Uncomment the next line to enable admin documentation: +    # 'django.contrib.admindocs', +    'rest_framework', +    'rest_framework.authtoken', +    'tests', +    'tests.accounts', +    'tests.records', +    'tests.users', +) + +# OAuth is optional and won't work if there is no oauth_provider & oauth2 +try: +    import oauth_provider +    import oauth2 +except ImportError: +    pass +else: +    INSTALLED_APPS += ( +        'oauth_provider', +    ) + +try: +    import provider +except ImportError: +    pass +else: +    INSTALLED_APPS += ( +        'provider', +        'provider.oauth2', +    ) + +# guardian is optional +try: +    import guardian +except ImportError: +    pass +else: +    ANONYMOUS_USER_ID = -1 +    AUTHENTICATION_BACKENDS = ( +        'django.contrib.auth.backends.ModelBackend', # default +        'guardian.backends.ObjectPermissionBackend', +    ) +    INSTALLED_APPS += ( +        'guardian', +    ) + +STATIC_URL = '/static/' + +PASSWORD_HASHERS = ( +    'django.contrib.auth.hashers.SHA1PasswordHasher', +    'django.contrib.auth.hashers.PBKDF2PasswordHasher', +    'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', +    'django.contrib.auth.hashers.BCryptPasswordHasher', +    'django.contrib.auth.hashers.MD5PasswordHasher', +    'django.contrib.auth.hashers.CryptPasswordHasher', +) + +AUTH_USER_MODEL = 'auth.User' + +import django + +if django.VERSION < (1, 3): +    INSTALLED_APPS += ('staticfiles',) + + +# If we're running on the Jenkins server we want to archive the coverage reports as XML. +import os +if os.environ.get('HUDSON_URL', None): +    TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner' +    TEST_OUTPUT_VERBOSE = True +    TEST_OUTPUT_DESCRIPTIONS = True +    TEST_OUTPUT_DIR = 'xmlrunner' diff --git a/tests/test_authentication.py b/tests/test_authentication.py new file mode 100644 index 00000000..d0290eac --- /dev/null +++ b/tests/test_authentication.py @@ -0,0 +1,669 @@ +from __future__ import unicode_literals +from django.contrib.auth.models import User +from django.http import HttpResponse +from django.test import TestCase +from django.utils import unittest +from django.utils.http import urlencode +from rest_framework import HTTP_HEADER_ENCODING +from rest_framework import exceptions +from rest_framework import permissions +from rest_framework import renderers +from rest_framework.response import Response +from rest_framework import status +from rest_framework.authentication import ( +    BaseAuthentication, +    TokenAuthentication, +    BasicAuthentication, +    SessionAuthentication, +    OAuthAuthentication, +    OAuth2Authentication +) +from rest_framework.authtoken.models import Token +from rest_framework.compat import patterns, url, include, six +from rest_framework.compat import oauth2_provider, oauth2_provider_scope +from rest_framework.compat import oauth, oauth_provider +from rest_framework.test import APIRequestFactory, APIClient +from rest_framework.views import APIView +import base64 +import time +import datetime + +factory = APIRequestFactory() + + +class MockView(APIView): +    permission_classes = (permissions.IsAuthenticated,) + +    def get(self, request): +        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + +    def post(self, request): +        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + +    def put(self, request): +        return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + + +urlpatterns = patterns('', +    (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), +    (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), +    (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), +    (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), +    (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), +    (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], +        permission_classes=[permissions.TokenHasReadWriteScope])) +) + +class OAuth2AuthenticationDebug(OAuth2Authentication): +    allow_query_params_token = True + +if oauth2_provider is not None: +    urlpatterns += patterns('', +        url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), +        url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), +        url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), +        url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], +            permission_classes=[permissions.TokenHasReadWriteScope])), +    ) + + +class BasicAuthTests(TestCase): +    """Basic authentication""" +    urls = 'tests.test_authentication' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +    def test_post_form_passing_basic_auth(self): +        """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" +        credentials = ('%s:%s' % (self.username, self.password)) +        base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) +        auth = 'Basic %s' % base64_credentials +        response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_json_passing_basic_auth(self): +        """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" +        credentials = ('%s:%s' % (self.username, self.password)) +        base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) +        auth = 'Basic %s' % base64_credentials +        response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_form_failing_basic_auth(self): +        """Ensure POSTing form over basic auth without correct credentials fails""" +        response = self.csrf_client.post('/basic/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + +    def test_post_json_failing_basic_auth(self): +        """Ensure POSTing json over basic auth without correct credentials fails""" +        response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json') +        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) +        self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') + + +class SessionAuthTests(TestCase): +    """User session authentication""" +    urls = 'tests.test_authentication' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.non_csrf_client = APIClient(enforce_csrf_checks=False) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +    def tearDown(self): +        self.csrf_client.logout() + +    def test_post_form_session_auth_failing_csrf(self): +        """ +        Ensure POSTing form over session authentication without CSRF token fails. +        """ +        self.csrf_client.login(username=self.username, password=self.password) +        response = self.csrf_client.post('/session/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_post_form_session_auth_passing(self): +        """ +        Ensure POSTing form over session authentication with logged in user and CSRF token passes. +        """ +        self.non_csrf_client.login(username=self.username, password=self.password) +        response = self.non_csrf_client.post('/session/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_put_form_session_auth_passing(self): +        """ +        Ensure PUTting form over session authentication with logged in user and CSRF token passes. +        """ +        self.non_csrf_client.login(username=self.username, password=self.password) +        response = self.non_csrf_client.put('/session/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_form_session_auth_failing(self): +        """ +        Ensure POSTing form over session authentication without logged in user fails. +        """ +        response = self.csrf_client.post('/session/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + +class TokenAuthTests(TestCase): +    """Token authentication""" +    urls = 'tests.test_authentication' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +        self.key = 'abcd1234' +        self.token = Token.objects.create(key=self.key, user=self.user) + +    def test_post_form_passing_token_auth(self): +        """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" +        auth = 'Token ' + self.key +        response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_json_passing_token_auth(self): +        """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" +        auth = "Token " + self.key +        response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_post_form_failing_token_auth(self): +        """Ensure POSTing form over token auth without correct credentials fails""" +        response = self.csrf_client.post('/token/', {'example': 'example'}) +        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + +    def test_post_json_failing_token_auth(self): +        """Ensure POSTing json over token auth without correct credentials fails""" +        response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') +        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + +    def test_token_has_auto_assigned_key_if_none_provided(self): +        """Ensure creating a token with no key will auto-assign a key""" +        self.token.delete() +        token = Token.objects.create(user=self.user) +        self.assertTrue(bool(token.key)) + +    def test_generate_key_returns_string(self): +        """Ensure generate_key returns a string""" +        token = Token() +        key = token.generate_key() +        self.assertTrue(isinstance(key, six.string_types)) + +    def test_token_login_json(self): +        """Ensure token login view using JSON POST works.""" +        client = APIClient(enforce_csrf_checks=True) +        response = client.post('/auth-token/', +                               {'username': self.username, 'password': self.password}, format='json') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['token'], self.key) + +    def test_token_login_json_bad_creds(self): +        """Ensure token login view using JSON POST fails if bad credentials are used.""" +        client = APIClient(enforce_csrf_checks=True) +        response = client.post('/auth-token/', +                               {'username': self.username, 'password': "badpass"}, format='json') +        self.assertEqual(response.status_code, 400) + +    def test_token_login_json_missing_fields(self): +        """Ensure token login view using JSON POST fails if missing fields.""" +        client = APIClient(enforce_csrf_checks=True) +        response = client.post('/auth-token/', +                               {'username': self.username}, format='json') +        self.assertEqual(response.status_code, 400) + +    def test_token_login_form(self): +        """Ensure token login view using form POST works.""" +        client = APIClient(enforce_csrf_checks=True) +        response = client.post('/auth-token/', +                               {'username': self.username, 'password': self.password}) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['token'], self.key) + + +class IncorrectCredentialsTests(TestCase): +    def test_incorrect_credentials(self): +        """ +        If a request contains bad authentication credentials, then +        authentication should run and error, even if no permissions +        are set on the view. +        """ +        class IncorrectCredentialsAuth(BaseAuthentication): +            def authenticate(self, request): +                raise exceptions.AuthenticationFailed('Bad credentials') + +        request = factory.get('/') +        view = MockView.as_view( +            authentication_classes=(IncorrectCredentialsAuth,), +            permission_classes=() +        ) +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) +        self.assertEqual(response.data, {'detail': 'Bad credentials'}) + + +class OAuthTests(TestCase): +    """OAuth 1.0a authentication""" +    urls = 'tests.test_authentication' + +    def setUp(self): +        # these imports are here because oauth is optional and hiding them in try..except block or compat +        # could obscure problems if something breaks +        from oauth_provider.models import Consumer, Scope +        from oauth_provider.models import Token as OAuthToken +        from oauth_provider import consts + +        self.consts = consts + +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +        self.CONSUMER_KEY = 'consumer_key' +        self.CONSUMER_SECRET = 'consumer_secret' +        self.TOKEN_KEY = "token_key" +        self.TOKEN_SECRET = "token_secret" + +        self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, +            name='example', user=self.user, status=self.consts.ACCEPTED) + +        self.scope = Scope.objects.create(name="resource name", url="api/") +        self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope, +            token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True +        ) + +    def _create_authorization_header(self): +        params = { +            'oauth_version': "1.0", +            'oauth_nonce': oauth.generate_nonce(), +            'oauth_timestamp': int(time.time()), +            'oauth_token': self.token.key, +            'oauth_consumer_key': self.consumer.key +        } + +        req = oauth.Request(method="GET", url="http://example.com", parameters=params) + +        signature_method = oauth.SignatureMethod_PLAINTEXT() +        req.sign_request(signature_method, self.consumer, self.token) + +        return req.to_header()["Authorization"] + +    def _create_authorization_url_parameters(self): +        params = { +            'oauth_version': "1.0", +            'oauth_nonce': oauth.generate_nonce(), +            'oauth_timestamp': int(time.time()), +            'oauth_token': self.token.key, +            'oauth_consumer_key': self.consumer.key +        } + +        req = oauth.Request(method="GET", url="http://example.com", parameters=params) + +        signature_method = oauth.SignatureMethod_PLAINTEXT() +        req.sign_request(signature_method, self.consumer, self.token) +        return dict(req) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_passing_oauth(self): +        """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_repeated_nonce_failing_oauth(self): +        """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails""" +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +        # simulate reply attack auth header containes already used (nonce, timestamp) pair +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_token_removed_failing_oauth(self): +        """Ensure POSTing when there is no OAuth access token in db fails""" +        self.token.delete() +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_consumer_status_not_accepted_failing_oauth(self): +        """Ensure POSTing when consumer status is anything other than ACCEPTED fails""" +        for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED): +            self.consumer.status = consumer_status +            self.consumer.save() + +            auth = self._create_authorization_header() +            response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +            self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_with_request_token_failing_oauth(self): +        """Ensure POSTing with unauthorized request token instead of access token fails""" +        self.token.token_type = self.token.REQUEST +        self.token.save() + +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_with_urlencoded_parameters(self): +        """Ensure POSTing with x-www-form-urlencoded auth parameters passes""" +        params = self._create_authorization_url_parameters() +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_get_form_with_url_parameters(self): +        """Ensure GETing with auth in url parameters passes""" +        params = self._create_authorization_url_parameters() +        response = self.csrf_client.get('/oauth/', params) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_hmac_sha1_signature_passes(self): +        """Ensure POSTing using HMAC_SHA1 signature method passes""" +        params = { +            'oauth_version': "1.0", +            'oauth_nonce': oauth.generate_nonce(), +            'oauth_timestamp': int(time.time()), +            'oauth_token': self.token.key, +            'oauth_consumer_key': self.consumer.key +        } + +        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + +        signature_method = oauth.SignatureMethod_HMAC_SHA1() +        req.sign_request(signature_method, self.consumer, self.token) +        auth = req.to_header()["Authorization"] + +        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_get_form_with_readonly_resource_passing_auth(self): +        """Ensure POSTing with a readonly scope instead of a write scope fails""" +        read_only_access_token = self.token +        read_only_access_token.scope.is_readonly = True +        read_only_access_token.scope.save() +        params = self._create_authorization_url_parameters() +        response = self.csrf_client.get('/oauth-with-scope/', params) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_with_readonly_resource_failing_auth(self): +        """Ensure POSTing with a readonly resource instead of a write scope fails""" +        read_only_access_token = self.token +        read_only_access_token.scope.is_readonly = True +        read_only_access_token.scope.save() +        params = self._create_authorization_url_parameters() +        response = self.csrf_client.post('/oauth-with-scope/', params) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_post_form_with_write_resource_passing_auth(self): +        """Ensure POSTing with a write resource succeed""" +        read_write_access_token = self.token +        read_write_access_token.scope.is_readonly = False +        read_write_access_token.scope.save() +        params = self._create_authorization_url_parameters() +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_bad_consumer_key(self): +        """Ensure POSTing using HMAC_SHA1 signature method passes""" +        params = { +            'oauth_version': "1.0", +            'oauth_nonce': oauth.generate_nonce(), +            'oauth_timestamp': int(time.time()), +            'oauth_token': self.token.key, +            'oauth_consumer_key': 'badconsumerkey' +        } + +        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + +        signature_method = oauth.SignatureMethod_HMAC_SHA1() +        req.sign_request(signature_method, self.consumer, self.token) +        auth = req.to_header()["Authorization"] + +        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) + +    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') +    @unittest.skipUnless(oauth, 'oauth2 not installed') +    def test_bad_token_key(self): +        """Ensure POSTing using HMAC_SHA1 signature method passes""" +        params = { +            'oauth_version': "1.0", +            'oauth_nonce': oauth.generate_nonce(), +            'oauth_timestamp': int(time.time()), +            'oauth_token': 'badtokenkey', +            'oauth_consumer_key': self.consumer.key +        } + +        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + +        signature_method = oauth.SignatureMethod_HMAC_SHA1() +        req.sign_request(signature_method, self.consumer, self.token) +        auth = req.to_header()["Authorization"] + +        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) + + +class OAuth2Tests(TestCase): +    """OAuth 2.0 authentication""" +    urls = 'tests.test_authentication' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +        self.CLIENT_ID = 'client_key' +        self.CLIENT_SECRET = 'client_secret' +        self.ACCESS_TOKEN = "access_token" +        self.REFRESH_TOKEN = "refresh_token" + +        self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create( +                client_id=self.CLIENT_ID, +                client_secret=self.CLIENT_SECRET, +                redirect_uri='', +                client_type=0, +                name='example', +                user=None, +            ) + +        self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create( +                token=self.ACCESS_TOKEN, +                client=self.oauth2_client, +                user=self.user, +            ) +        self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create( +                user=self.user, +                access_token=self.access_token, +                client=self.oauth2_client +            ) + +    def _create_authorization_header(self, token=None): +        return "Bearer {0}".format(token or self.access_token.token) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_with_wrong_authorization_header_token_type_failing(self): +        """Ensure that a wrong token type lead to the correct HTTP error status code""" +        auth = "Wrong token-type-obsviously" +        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_with_wrong_authorization_header_token_format_failing(self): +        """Ensure that a wrong token format lead to the correct HTTP error status code""" +        auth = "Bearer wrong token format" +        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_with_wrong_authorization_header_token_failing(self): +        """Ensure that a wrong token lead to the correct HTTP error status code""" +        auth = "Bearer wrong-token" +        response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 401) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_passing_auth(self): +        """Ensure GETing form over OAuth with correct client credentials succeed""" +        auth = self._create_authorization_header() +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_passing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in form data succeed""" +        response = self.csrf_client.post('/oauth2-test/', +                data={'access_token': self.access_token.token}) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_passing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" +        query = urlencode({'access_token': self.access_token.token}) +        response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_failing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" +        query = urlencode({'access_token': self.access_token.token}) +        response = self.csrf_client.get('/oauth2-test/?%s' % query) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_passing_auth(self): +        """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_token_removed_failing_auth(self): +        """Ensure POSTing when there is no OAuth access token in db fails""" +        self.access_token.delete() +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_with_refresh_token_failing_auth(self): +        """Ensure POSTing with refresh token instead of access token fails""" +        auth = self._create_authorization_header(token=self.refresh_token.token) +        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_with_expired_access_token_failing_auth(self): +        """Ensure POSTing with expired access token fails with an 'Invalid token' error""" +        self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10)  # 10 seconds late +        self.access_token.save() +        auth = self._create_authorization_header() +        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) +        self.assertIn('Invalid token', response.content) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_with_invalid_scope_failing_auth(self): +        """Ensure POSTing with a readonly scope instead of a write scope fails""" +        read_only_access_token = self.access_token +        read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] +        read_only_access_token.save() +        auth = self._create_authorization_header(token=read_only_access_token.token) +        response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) +        response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_with_valid_scope_passing_auth(self): +        """Ensure POSTing with a write scope succeed""" +        read_write_access_token = self.access_token +        read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] +        read_write_access_token.save() +        auth = self._create_authorization_header(token=read_write_access_token.token) +        response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) +        self.assertEqual(response.status_code, 200) + + +class FailingAuthAccessedInRenderer(TestCase): +    def setUp(self): +        class AuthAccessingRenderer(renderers.BaseRenderer): +            media_type = 'text/plain' +            format = 'txt' + +            def render(self, data, media_type=None, renderer_context=None): +                request = renderer_context['request'] +                if request.user.is_authenticated(): +                    return b'authenticated' +                return b'not authenticated' + +        class FailingAuth(BaseAuthentication): +            def authenticate(self, request): +                raise exceptions.AuthenticationFailed('authentication failed') + +        class ExampleView(APIView): +            authentication_classes = (FailingAuth,) +            renderer_classes = (AuthAccessingRenderer,) + +            def get(self, request): +                return Response({'foo': 'bar'}) + +        self.view = ExampleView.as_view() + +    def test_failing_auth_accessed_in_renderer(self): +        """ +        When authentication fails the renderer should still be able to access +        `request.user` without raising an exception. Particularly relevant +        to HTML responses that might reasonably access `request.user`. +        """ +        request = factory.get('/') +        response = self.view(request) +        content = response.render().content +        self.assertEqual(content, b'not authenticated') diff --git a/tests/test_breadcrumbs.py b/tests/test_breadcrumbs.py new file mode 100644 index 00000000..78edc603 --- /dev/null +++ b/tests/test_breadcrumbs.py @@ -0,0 +1,73 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.views import APIView + + +class Root(APIView): +    pass + + +class ResourceRoot(APIView): +    pass + + +class ResourceInstance(APIView): +    pass + + +class NestedResourceRoot(APIView): +    pass + + +class NestedResourceInstance(APIView): +    pass + +urlpatterns = patterns('', +    url(r'^$', Root.as_view()), +    url(r'^resource/$', ResourceRoot.as_view()), +    url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), +    url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()), +    url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()), +) + + +class BreadcrumbTests(TestCase): +    """Tests the breadcrumb functionality used by the HTML renderer.""" + +    urls = 'tests.test_breadcrumbs' + +    def test_root_breadcrumbs(self): +        url = '/' +        self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) + +    def test_resource_root_breadcrumbs(self): +        url = '/resource/' +        self.assertEqual(get_breadcrumbs(url), [('Root', '/'), +                                            ('Resource Root', '/resource/')]) + +    def test_resource_instance_breadcrumbs(self): +        url = '/resource/123' +        self.assertEqual(get_breadcrumbs(url), [('Root', '/'), +                                            ('Resource Root', '/resource/'), +                                            ('Resource Instance', '/resource/123')]) + +    def test_nested_resource_breadcrumbs(self): +        url = '/resource/123/' +        self.assertEqual(get_breadcrumbs(url), [('Root', '/'), +                                            ('Resource Root', '/resource/'), +                                            ('Resource Instance', '/resource/123'), +                                            ('Nested Resource Root', '/resource/123/')]) + +    def test_nested_resource_instance_breadcrumbs(self): +        url = '/resource/123/abc' +        self.assertEqual(get_breadcrumbs(url), [('Root', '/'), +                                            ('Resource Root', '/resource/'), +                                            ('Resource Instance', '/resource/123'), +                                            ('Nested Resource Root', '/resource/123/'), +                                            ('Nested Resource Instance', '/resource/123/abc')]) + +    def test_broken_url_breadcrumbs_handled_gracefully(self): +        url = '/foobar' +        self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 00000000..195f0ba3 --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,157 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import status +from rest_framework.authentication import BasicAuthentication +from rest_framework.parsers import JSONParser +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.renderers import JSONRenderer +from rest_framework.test import APIRequestFactory +from rest_framework.throttling import UserRateThrottle +from rest_framework.views import APIView +from rest_framework.decorators import ( +    api_view, +    renderer_classes, +    parser_classes, +    authentication_classes, +    throttle_classes, +    permission_classes, +) + + +class DecoratorTestCase(TestCase): + +    def setUp(self): +        self.factory = APIRequestFactory() + +    def _finalize_response(self, request, response, *args, **kwargs): +        response.request = request +        return APIView.finalize_response(self, request, response, *args, **kwargs) + +    def test_api_view_incorrect(self): +        """ +        If @api_view is not applied correct, we should raise an assertion. +        """ + +        @api_view +        def view(request): +            return Response() + +        request = self.factory.get('/') +        self.assertRaises(AssertionError, view, request) + +    def test_api_view_incorrect_arguments(self): +        """ +        If @api_view is missing arguments, we should raise an assertion. +        """ + +        with self.assertRaises(AssertionError): +            @api_view('GET') +            def view(request): +                return Response() + +    def test_calling_method(self): + +        @api_view(['GET']) +        def view(request): +            return Response({}) + +        request = self.factory.get('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        request = self.factory.post('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + +    def test_calling_put_method(self): + +        @api_view(['GET', 'PUT']) +        def view(request): +            return Response({}) + +        request = self.factory.put('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        request = self.factory.post('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + +    def test_calling_patch_method(self): + +        @api_view(['GET', 'PATCH']) +        def view(request): +            return Response({}) + +        request = self.factory.patch('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        request = self.factory.post('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + +    def test_renderer_classes(self): + +        @api_view(['GET']) +        @renderer_classes([JSONRenderer]) +        def view(request): +            return Response({}) + +        request = self.factory.get('/') +        response = view(request) +        self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer)) + +    def test_parser_classes(self): + +        @api_view(['GET']) +        @parser_classes([JSONParser]) +        def view(request): +            self.assertEqual(len(request.parsers), 1) +            self.assertTrue(isinstance(request.parsers[0], +                                       JSONParser)) +            return Response({}) + +        request = self.factory.get('/') +        view(request) + +    def test_authentication_classes(self): + +        @api_view(['GET']) +        @authentication_classes([BasicAuthentication]) +        def view(request): +            self.assertEqual(len(request.authenticators), 1) +            self.assertTrue(isinstance(request.authenticators[0], +                                       BasicAuthentication)) +            return Response({}) + +        request = self.factory.get('/') +        view(request) + +    def test_permission_classes(self): + +        @api_view(['GET']) +        @permission_classes([IsAuthenticated]) +        def view(request): +            return Response({}) + +        request = self.factory.get('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_throttle_classes(self): +        class OncePerDayUserThrottle(UserRateThrottle): +            rate = '1/day' + +        @api_view(['GET']) +        @throttle_classes([OncePerDayUserThrottle]) +        def view(request): +            return Response({}) + +        request = self.factory.get('/') +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        response = view(request) +        self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) diff --git a/tests/test_description.py b/tests/test_description.py new file mode 100644 index 00000000..1e481f06 --- /dev/null +++ b/tests/test_description.py @@ -0,0 +1,108 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.compat import apply_markdown, smart_text +from rest_framework.views import APIView +from .description import ViewWithNonASCIICharactersInDocstring +from .description import UTF8_TEST_DOCSTRING + +# We check that docstrings get nicely un-indented. +DESCRIPTION = """an example docstring +==================== + +* list +* list + +another header +-------------- + +    code block + +indented + +# hash style header #""" + +# If markdown is installed we also test it's working +# (and that our wrapped forces '=' to h2 and '-' to h3) + +# We support markdown < 2.1 and markdown >= 2.1 +MARKED_DOWN_lt_21 = """<h2>an example docstring</h2> +<ul> +<li>list</li> +<li>list</li> +</ul> +<h3>another header</h3> +<pre><code>code block +</code></pre> +<p>indented</p> +<h2 id="hash_style_header">hash style header</h2>""" + +MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2> +<ul> +<li>list</li> +<li>list</li> +</ul> +<h3 id="another-header">another header</h3> +<pre><code>code block +</code></pre> +<p>indented</p> +<h2 id="hash-style-header">hash style header</h2>""" + + +class TestViewNamesAndDescriptions(TestCase): +    def test_view_name_uses_class_name(self): +        """ +        Ensure view names are based on the class name. +        """ +        class MockView(APIView): +            pass +        self.assertEqual(MockView().get_view_name(), 'Mock') + +    def test_view_description_uses_docstring(self): +        """Ensure view descriptions are based on the docstring.""" +        class MockView(APIView): +            """an example docstring +            ==================== + +            * list +            * list + +            another header +            -------------- + +                code block + +            indented + +            # hash style header #""" + +        self.assertEqual(MockView().get_view_description(), DESCRIPTION) + +    def test_view_description_supports_unicode(self): +        """ +        Unicode in docstrings should be respected. +        """ + +        self.assertEqual( +            ViewWithNonASCIICharactersInDocstring().get_view_description(), +            smart_text(UTF8_TEST_DOCSTRING) +        ) + +    def test_view_description_can_be_empty(self): +        """ +        Ensure that if a view has no docstring, +        then it's description is the empty string. +        """ +        class MockView(APIView): +            pass +        self.assertEqual(MockView().get_view_description(), '') + +    def test_markdown(self): +        """ +        Ensure markdown to HTML works as expected. +        """ +        if apply_markdown: +            gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21 +            lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 +            self.assertTrue(gte_21_match or lt_21_match) diff --git a/tests/test_fields.py b/tests/test_fields.py new file mode 100644 index 00000000..e65a2fb3 --- /dev/null +++ b/tests/test_fields.py @@ -0,0 +1,984 @@ +""" +General serializer field tests. +""" +from __future__ import unicode_literals + +import datetime +from decimal import Decimal +from uuid import uuid4 +from django.core import validators +from django.db import models +from django.test import TestCase +from django.utils.datastructures import SortedDict +from rest_framework import serializers +from tests.models import RESTFrameworkModel + + +class TimestampedModel(models.Model): +    added = models.DateTimeField(auto_now_add=True) +    updated = models.DateTimeField(auto_now=True) + + +class CharPrimaryKeyModel(models.Model): +    id = models.CharField(max_length=20, primary_key=True) + + +class TimestampedModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = TimestampedModel + + +class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = CharPrimaryKeyModel + + +class TimeFieldModel(models.Model): +    clock = models.TimeField() + + +class TimeFieldModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = TimeFieldModel + + +SAMPLE_CHOICES = [ +    ('red', 'Red'), +    ('green', 'Green'), +    ('blue', 'Blue'), +] + + +class ChoiceFieldModel(models.Model): +    choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255) + + +class ChoiceFieldModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ChoiceFieldModel + + +class ChoiceFieldModelWithNull(models.Model): +    choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255) + + +class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer): +    class Meta: +        model = ChoiceFieldModelWithNull + + +class BasicFieldTests(TestCase): +    def test_auto_now_fields_read_only(self): +        """ +        auto_now and auto_now_add fields should be read_only by default. +        """ +        serializer = TimestampedModelSerializer() +        self.assertEqual(serializer.fields['added'].read_only, True) + +    def test_auto_pk_fields_read_only(self): +        """ +        AutoField fields should be read_only by default. +        """ +        serializer = TimestampedModelSerializer() +        self.assertEqual(serializer.fields['id'].read_only, True) + +    def test_non_auto_pk_fields_not_read_only(self): +        """ +        PK fields other than AutoField fields should not be read_only by default. +        """ +        serializer = CharPrimaryKeyModelSerializer() +        self.assertEqual(serializer.fields['id'].read_only, False) + +    def test_dict_field_ordering(self): +        """ +        Field should preserve dictionary ordering, if it exists. +        See: https://github.com/tomchristie/django-rest-framework/issues/832 +        """ +        ret = SortedDict() +        ret['c'] = 1 +        ret['b'] = 1 +        ret['a'] = 1 +        ret['z'] = 1 +        field = serializers.Field() +        keys = list(field.to_native(ret).keys()) +        self.assertEqual(keys, ['c', 'b', 'a', 'z']) + + +class DateFieldTest(TestCase): +    """ +    Tests for the DateFieldTest from_native() and to_native() behavior +    """ + +    def test_from_native_string(self): +        """ +        Make sure from_native() accepts default iso input formats. +        """ +        f = serializers.DateField() +        result_1 = f.from_native('1984-07-31') + +        self.assertEqual(datetime.date(1984, 7, 31), result_1) + +    def test_from_native_datetime_date(self): +        """ +        Make sure from_native() accepts a datetime.date instance. +        """ +        f = serializers.DateField() +        result_1 = f.from_native(datetime.date(1984, 7, 31)) + +        self.assertEqual(result_1, datetime.date(1984, 7, 31)) + +    def test_from_native_custom_format(self): +        """ +        Make sure from_native() accepts custom input formats. +        """ +        f = serializers.DateField(input_formats=['%Y -- %d']) +        result = f.from_native('1984 -- 31') + +        self.assertEqual(datetime.date(1984, 1, 31), result) + +    def test_from_native_invalid_default_on_custom_format(self): +        """ +        Make sure from_native() don't accept default formats if custom format is preset +        """ +        f = serializers.DateField(input_formats=['%Y -- %d']) + +        try: +            f.from_native('1984-07-31') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.DateField() +        result = f.from_native('') + +        self.assertEqual(result, None) + +    def test_from_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DateField() +        result = f.from_native(None) + +        self.assertEqual(result, None) + +    def test_from_native_invalid_date(self): +        """ +        Make sure from_native() raises a ValidationError on passing an invalid date. +        """ +        f = serializers.DateField() + +        try: +            f.from_native('1984-13-31') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_invalid_format(self): +        """ +        Make sure from_native() raises a ValidationError on passing an invalid format. +        """ +        f = serializers.DateField() + +        try: +            f.from_native('1984 -- 31') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_to_native(self): +        """ +        Make sure to_native() returns datetime as default. +        """ +        f = serializers.DateField() + +        result_1 = f.to_native(datetime.date(1984, 7, 31)) + +        self.assertEqual(datetime.date(1984, 7, 31), result_1) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with 'iso-8601' returns iso formated date. +        """ +        f = serializers.DateField(format='iso-8601') + +        result_1 = f.to_native(datetime.date(1984, 7, 31)) + +        self.assertEqual('1984-07-31', result_1) + +    def test_to_native_custom_format(self): +        """ +        Make sure to_native() returns correct custom format. +        """ +        f = serializers.DateField(format="%Y - %m.%d") + +        result_1 = f.to_native(datetime.date(1984, 7, 31)) + +        self.assertEqual('1984 - 07.31', result_1) + +    def test_to_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DateField(required=False) +        self.assertEqual(None, f.to_native(None)) + + +class DateTimeFieldTest(TestCase): +    """ +    Tests for the DateTimeField from_native() and to_native() behavior +    """ + +    def test_from_native_string(self): +        """ +        Make sure from_native() accepts default iso input formats. +        """ +        f = serializers.DateTimeField() +        result_1 = f.from_native('1984-07-31 04:31') +        result_2 = f.from_native('1984-07-31 04:31:59') +        result_3 = f.from_native('1984-07-31 04:31:59.000200') + +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3) + +    def test_from_native_datetime_datetime(self): +        """ +        Make sure from_native() accepts a datetime.datetime instance. +        """ +        f = serializers.DateTimeField() +        result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31)) +        result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) +        result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + +        self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31)) +        self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59)) +        self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + +    def test_from_native_custom_format(self): +        """ +        Make sure from_native() accepts custom input formats. +        """ +        f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) +        result = f.from_native('1984 -- 04:59') + +        self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result) + +    def test_from_native_invalid_default_on_custom_format(self): +        """ +        Make sure from_native() don't accept default formats if custom format is preset +        """ +        f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) + +        try: +            f.from_native('1984-07-31 04:31:59') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.DateTimeField() +        result = f.from_native('') + +        self.assertEqual(result, None) + +    def test_from_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DateTimeField() +        result = f.from_native(None) + +        self.assertEqual(result, None) + +    def test_from_native_invalid_datetime(self): +        """ +        Make sure from_native() raises a ValidationError on passing an invalid datetime. +        """ +        f = serializers.DateTimeField() + +        try: +            f.from_native('04:61:59') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " +                                          "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_invalid_format(self): +        """ +        Make sure from_native() raises a ValidationError on passing an invalid format. +        """ +        f = serializers.DateTimeField() + +        try: +            f.from_native('04 -- 31') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " +                                          "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_to_native(self): +        """ +        Make sure to_native() returns isoformat as default. +        """ +        f = serializers.DateTimeField() + +        result_1 = f.to_native(datetime.datetime(1984, 7, 31)) +        result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) +        result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) +        result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + +        self.assertEqual(datetime.datetime(1984, 7, 31), result_1) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with format=iso-8601 returns iso formatted datetime. +        """ +        f = serializers.DateTimeField(format='iso-8601') + +        result_1 = f.to_native(datetime.datetime(1984, 7, 31)) +        result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) +        result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) +        result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + +        self.assertEqual('1984-07-31T00:00:00', result_1) +        self.assertEqual('1984-07-31T04:31:00', result_2) +        self.assertEqual('1984-07-31T04:31:59', result_3) +        self.assertEqual('1984-07-31T04:31:59.000200', result_4) + +    def test_to_native_custom_format(self): +        """ +        Make sure to_native() returns correct custom format. +        """ +        f = serializers.DateTimeField(format="%Y - %H:%M") + +        result_1 = f.to_native(datetime.datetime(1984, 7, 31)) +        result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) +        result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) +        result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + +        self.assertEqual('1984 - 00:00', result_1) +        self.assertEqual('1984 - 04:31', result_2) +        self.assertEqual('1984 - 04:31', result_3) +        self.assertEqual('1984 - 04:31', result_4) + +    def test_to_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DateTimeField(required=False) +        self.assertEqual(None, f.to_native(None)) + + +class TimeFieldTest(TestCase): +    """ +    Tests for the TimeField from_native() and to_native() behavior +    """ + +    def test_from_native_string(self): +        """ +        Make sure from_native() accepts default iso input formats. +        """ +        f = serializers.TimeField() +        result_1 = f.from_native('04:31') +        result_2 = f.from_native('04:31:59') +        result_3 = f.from_native('04:31:59.000200') + +        self.assertEqual(datetime.time(4, 31), result_1) +        self.assertEqual(datetime.time(4, 31, 59), result_2) +        self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + +    def test_from_native_datetime_time(self): +        """ +        Make sure from_native() accepts a datetime.time instance. +        """ +        f = serializers.TimeField() +        result_1 = f.from_native(datetime.time(4, 31)) +        result_2 = f.from_native(datetime.time(4, 31, 59)) +        result_3 = f.from_native(datetime.time(4, 31, 59, 200)) + +        self.assertEqual(result_1, datetime.time(4, 31)) +        self.assertEqual(result_2, datetime.time(4, 31, 59)) +        self.assertEqual(result_3, datetime.time(4, 31, 59, 200)) + +    def test_from_native_custom_format(self): +        """ +        Make sure from_native() accepts custom input formats. +        """ +        f = serializers.TimeField(input_formats=['%H -- %M']) +        result = f.from_native('04 -- 31') + +        self.assertEqual(datetime.time(4, 31), result) + +    def test_from_native_invalid_default_on_custom_format(self): +        """ +        Make sure from_native() don't accept default formats if custom format is preset +        """ +        f = serializers.TimeField(input_formats=['%H -- %M']) + +        try: +            f.from_native('04:31:59') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.TimeField() +        result = f.from_native('') + +        self.assertEqual(result, None) + +    def test_from_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.TimeField() +        result = f.from_native(None) + +        self.assertEqual(result, None) + +    def test_from_native_invalid_time(self): +        """ +        Make sure from_native() raises a ValidationError on passing an invalid time. +        """ +        f = serializers.TimeField() + +        try: +            f.from_native('04:61:59') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " +                                          "hh:mm[:ss[.uuuuuu]]"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_invalid_format(self): +        """ +        Make sure from_native() raises a ValidationError on passing an invalid format. +        """ +        f = serializers.TimeField() + +        try: +            f.from_native('04 -- 31') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " +                                          "hh:mm[:ss[.uuuuuu]]"]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_to_native(self): +        """ +        Make sure to_native() returns time object as default. +        """ +        f = serializers.TimeField() +        result_1 = f.to_native(datetime.time(4, 31)) +        result_2 = f.to_native(datetime.time(4, 31, 59)) +        result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + +        self.assertEqual(datetime.time(4, 31), result_1) +        self.assertEqual(datetime.time(4, 31, 59), result_2) +        self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with format='iso-8601' returns iso formatted time. +        """ +        f = serializers.TimeField(format='iso-8601') +        result_1 = f.to_native(datetime.time(4, 31)) +        result_2 = f.to_native(datetime.time(4, 31, 59)) +        result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + +        self.assertEqual('04:31:00', result_1) +        self.assertEqual('04:31:59', result_2) +        self.assertEqual('04:31:59.000200', result_3) + +    def test_to_native_custom_format(self): +        """ +        Make sure to_native() returns correct custom format. +        """ +        f = serializers.TimeField(format="%H - %S [%f]") +        result_1 = f.to_native(datetime.time(4, 31)) +        result_2 = f.to_native(datetime.time(4, 31, 59)) +        result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + +        self.assertEqual('04 - 00 [000000]', result_1) +        self.assertEqual('04 - 59 [000000]', result_2) +        self.assertEqual('04 - 59 [000200]', result_3) + + +class DecimalFieldTest(TestCase): +    """ +    Tests for the DecimalField from_native() and to_native() behavior +    """ + +    def test_from_native_string(self): +        """ +        Make sure from_native() accepts string values +        """ +        f = serializers.DecimalField() +        result_1 = f.from_native('9000') +        result_2 = f.from_native('1.00000001') + +        self.assertEqual(Decimal('9000'), result_1) +        self.assertEqual(Decimal('1.00000001'), result_2) + +    def test_from_native_invalid_string(self): +        """ +        Make sure from_native() raises ValidationError on passing invalid string +        """ +        f = serializers.DecimalField() + +        try: +            f.from_native('123.45.6') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Enter a number."]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_integer(self): +        """ +        Make sure from_native() accepts integer values +        """ +        f = serializers.DecimalField() +        result = f.from_native(9000) + +        self.assertEqual(Decimal('9000'), result) + +    def test_from_native_float(self): +        """ +        Make sure from_native() accepts float values +        """ +        f = serializers.DecimalField() +        result = f.from_native(1.00000001) + +        self.assertEqual(Decimal('1.00000001'), result) + +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.DecimalField() +        result = f.from_native('') + +        self.assertEqual(result, None) + +    def test_from_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DecimalField() +        result = f.from_native(None) + +        self.assertEqual(result, None) + +    def test_to_native(self): +        """ +        Make sure to_native() returns Decimal as string. +        """ +        f = serializers.DecimalField() + +        result_1 = f.to_native(Decimal('9000')) +        result_2 = f.to_native(Decimal('1.00000001')) + +        self.assertEqual(Decimal('9000'), result_1) +        self.assertEqual(Decimal('1.00000001'), result_2) + +    def test_to_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DecimalField(required=False) +        self.assertEqual(None, f.to_native(None)) + +    def test_valid_serialization(self): +        """ +        Make sure the serializer works correctly +        """ +        class DecimalSerializer(serializers.Serializer): +            decimal_field = serializers.DecimalField(max_value=9010, +                                                     min_value=9000, +                                                     max_digits=6, +                                                     decimal_places=2) + +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) + +        self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) +        self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) +        self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) + +    def test_raise_max_value(self): +        """ +        Make sure max_value violations raises ValidationError +        """ +        class DecimalSerializer(serializers.Serializer): +            decimal_field = serializers.DecimalField(max_value=100) + +        s = DecimalSerializer(data={'decimal_field': '123'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure this value is less than or equal to 100.']}) + +    def test_raise_min_value(self): +        """ +        Make sure min_value violations raises ValidationError +        """ +        class DecimalSerializer(serializers.Serializer): +            decimal_field = serializers.DecimalField(min_value=100) + +        s = DecimalSerializer(data={'decimal_field': '99'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) + +    def test_raise_max_digits(self): +        """ +        Make sure max_digits violations raises ValidationError +        """ +        class DecimalSerializer(serializers.Serializer): +            decimal_field = serializers.DecimalField(max_digits=5) + +        s = DecimalSerializer(data={'decimal_field': '123.456'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) + +    def test_raise_max_decimal_places(self): +        """ +        Make sure max_decimal_places violations raises ValidationError +        """ +        class DecimalSerializer(serializers.Serializer): +            decimal_field = serializers.DecimalField(decimal_places=3) + +        s = DecimalSerializer(data={'decimal_field': '123.4567'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) + +    def test_raise_max_whole_digits(self): +        """ +        Make sure max_whole_digits violations raises ValidationError +        """ +        class DecimalSerializer(serializers.Serializer): +            decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) + +        s = DecimalSerializer(data={'decimal_field': '12345.6'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) + + +class ChoiceFieldTests(TestCase): +    """ +    Tests for the ChoiceField options generator +    """ +    def test_choices_required(self): +        """ +        Make sure proper choices are rendered if field is required +        """ +        f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES) +        self.assertEqual(f.choices, SAMPLE_CHOICES) + +    def test_choices_not_required(self): +        """ +        Make sure proper choices (plus blank) are rendered if the field isn't required +        """ +        f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) +        self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES) + +    def test_invalid_choice_model(self): +        s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'}) +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']}) +        self.assertEqual(s.data['choice'], '') + +    def test_empty_choice_model(self): +        """ +        Test that the 'empty' value is correctly passed and used depending on +        the 'null' property on the model field. +        """ +        s = ChoiceFieldModelSerializer(data={'choice': ''}) +        self.assertTrue(s.is_valid()) +        self.assertEqual(s.data['choice'], '') + +        s = ChoiceFieldModelWithNullSerializer(data={'choice': ''}) +        self.assertTrue(s.is_valid()) +        self.assertEqual(s.data['choice'], None) + +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns an empty string on empty param by default. +        """ +        f = serializers.ChoiceField(choices=SAMPLE_CHOICES) +        self.assertEqual(f.from_native(''), '') +        self.assertEqual(f.from_native(None), '') + +    def test_from_native_empty_override(self): +        """ +        Make sure you can override from_native() behavior regarding empty values. +        """ +        f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None) +        self.assertEqual(f.from_native(''), None) +        self.assertEqual(f.from_native(None), None) + +    def test_metadata_choices(self): +        """ +        Make sure proper choices are included in the field's metadata. +        """ +        choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES] +        f = serializers.ChoiceField(choices=SAMPLE_CHOICES) +        self.assertEqual(f.metadata()['choices'], choices) + +    def test_metadata_choices_not_required(self): +        """ +        Make sure proper choices are included in the field's metadata. +        """ +        choices = [{'value': v, 'display_name': n} +                   for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES] +        f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) +        self.assertEqual(f.metadata()['choices'], choices) + + +class EmailFieldTests(TestCase): +    """ +    Tests for EmailField attribute values +    """ + +    class EmailFieldModel(RESTFrameworkModel): +        email_field = models.EmailField(blank=True) + +    class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel): +        email_field = models.EmailField(max_length=150, blank=True) + +    def test_default_model_value(self): +        class EmailFieldSerializer(serializers.ModelSerializer): +            class Meta: +                model = self.EmailFieldModel + +        serializer = EmailFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75) + +    def test_given_model_value(self): +        class EmailFieldSerializer(serializers.ModelSerializer): +            class Meta: +                model = self.EmailFieldWithGivenMaxLengthModel + +        serializer = EmailFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150) + +    def test_given_serializer_value(self): +        class EmailFieldSerializer(serializers.ModelSerializer): +            email_field = serializers.EmailField(source='email_field', max_length=20, required=False) + +            class Meta: +                model = self.EmailFieldModel + +        serializer = EmailFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20) + + +class SlugFieldTests(TestCase): +    """ +    Tests for SlugField attribute values +    """ + +    class SlugFieldModel(RESTFrameworkModel): +        slug_field = models.SlugField(blank=True) + +    class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel): +        slug_field = models.SlugField(max_length=84, blank=True) + +    def test_default_model_value(self): +        class SlugFieldSerializer(serializers.ModelSerializer): +            class Meta: +                model = self.SlugFieldModel + +        serializer = SlugFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50) + +    def test_given_model_value(self): +        class SlugFieldSerializer(serializers.ModelSerializer): +            class Meta: +                model = self.SlugFieldWithGivenMaxLengthModel + +        serializer = SlugFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84) + +    def test_given_serializer_value(self): +        class SlugFieldSerializer(serializers.ModelSerializer): +            slug_field = serializers.SlugField(source='slug_field', +                                               max_length=20, required=False) + +            class Meta: +                model = self.SlugFieldModel + +        serializer = SlugFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['slug_field'], +                                 'max_length'), 20) + +    def test_invalid_slug(self): +        """ +        Make sure an invalid slug raises ValidationError +        """ +        class SlugFieldSerializer(serializers.ModelSerializer): +            slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True) + +            class Meta: +                model = self.SlugFieldModel + +        s = SlugFieldSerializer(data={'slug_field': 'a b'}) + +        self.assertEqual(s.is_valid(), False) +        self.assertEqual(s.errors,  {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]}) + + +class URLFieldTests(TestCase): +    """ +    Tests for URLField attribute values. + +    (Includes test for #1210, checking that validators can be overridden.) +    """ + +    class URLFieldModel(RESTFrameworkModel): +        url_field = models.URLField(blank=True) + +    class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel): +        url_field = models.URLField(max_length=128, blank=True) + +    def test_default_model_value(self): +        class URLFieldSerializer(serializers.ModelSerializer): +            class Meta: +                model = self.URLFieldModel + +        serializer = URLFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['url_field'], +                                 'max_length'), 200) + +    def test_given_model_value(self): +        class URLFieldSerializer(serializers.ModelSerializer): +            class Meta: +                model = self.URLFieldWithGivenMaxLengthModel + +        serializer = URLFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['url_field'], +                                 'max_length'), 128) + +    def test_given_serializer_value(self): +        class URLFieldSerializer(serializers.ModelSerializer): +            url_field = serializers.URLField(source='url_field', +                                             max_length=20, required=False) + +            class Meta: +                model = self.URLFieldWithGivenMaxLengthModel + +        serializer = URLFieldSerializer(data={}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(getattr(serializer.fields['url_field'], +                         'max_length'), 20) + +    def test_validators_can_be_overridden(self): +        url_field = serializers.URLField(validators=[]) +        validators = url_field.validators +        self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators') + + +class FieldMetadata(TestCase): +    def setUp(self): +        self.required_field = serializers.Field() +        self.required_field.label = uuid4().hex +        self.required_field.required = True + +        self.optional_field = serializers.Field() +        self.optional_field.label = uuid4().hex +        self.optional_field.required = False + +    def test_required(self): +        self.assertEqual(self.required_field.metadata()['required'], True) + +    def test_optional(self): +        self.assertEqual(self.optional_field.metadata()['required'], False) + +    def test_label(self): +        for field in (self.required_field, self.optional_field): +            self.assertEqual(field.metadata()['label'], field.label) + + +class FieldCallableDefault(TestCase): +    def setUp(self): +        self.simple_callable = lambda: 'foo bar' + +    def test_default_can_be_simple_callable(self): +        """ +        Ensure that the 'default' argument can also be a simple callable. +        """ +        field = serializers.WritableField(default=self.simple_callable) +        into = {} +        field.field_from_native({}, {}, 'field', into) +        self.assertEqual(into, {'field': 'foo bar'}) + + +class CustomIntegerField(TestCase): +    """ +        Test that custom fields apply min_value and max_value constraints +    """ +    def test_custom_fields_can_be_validated_for_value(self): + +        class MoneyField(models.PositiveIntegerField): +            pass + +        class EntryModel(models.Model): +            bank = MoneyField(validators=[validators.MaxValueValidator(100)]) + +        class EntrySerializer(serializers.ModelSerializer): +            class Meta: +                model = EntryModel + +        entry = EntryModel(bank=1) + +        serializer = EntrySerializer(entry, data={"bank": 11}) +        self.assertTrue(serializer.is_valid()) + +        serializer = EntrySerializer(entry, data={"bank": -1}) +        self.assertFalse(serializer.is_valid()) + +        serializer = EntrySerializer(entry, data={"bank": 101}) +        self.assertFalse(serializer.is_valid()) + + +class BooleanField(TestCase): +    """ +        Tests for BooleanField +    """ +    def test_boolean_required(self): +        class BooleanRequiredSerializer(serializers.Serializer): +            bool_field = serializers.BooleanField(required=True) + +        self.assertFalse(BooleanRequiredSerializer(data={}).is_valid()) diff --git a/tests/test_files.py b/tests/test_files.py new file mode 100644 index 00000000..78f4cf42 --- /dev/null +++ b/tests/test_files.py @@ -0,0 +1,95 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers +from rest_framework.compat import BytesIO +from rest_framework.compat import six +import datetime + + +class UploadedFile(object): +    def __init__(self, file=None, created=None): +        self.file = file +        self.created = created or datetime.datetime.now() + + +class UploadedFileSerializer(serializers.Serializer): +    file = serializers.FileField(required=False) +    created = serializers.DateTimeField() + +    def restore_object(self, attrs, instance=None): +        if instance: +            instance.file = attrs['file'] +            instance.created = attrs['created'] +            return instance +        return UploadedFile(**attrs) + + +class FileSerializerTests(TestCase): +    def test_create(self): +        now = datetime.datetime.now() +        file = BytesIO(six.b('stuff')) +        file.name = 'stuff.txt' +        file.size = len(file.getvalue()) +        serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) +        uploaded_file = UploadedFile(file=file, created=now) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.object.created, uploaded_file.created) +        self.assertEqual(serializer.object.file, uploaded_file.file) +        self.assertFalse(serializer.object is uploaded_file) + +    def test_creation_failure(self): +        """ +        Passing files=None should result in an ValidationError + +        Regression test for: +        https://github.com/tomchristie/django-rest-framework/issues/542 +        """ +        now = datetime.datetime.now() + +        serializer = UploadedFileSerializer(data={'created': now}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.object.created, now) +        self.assertIsNone(serializer.object.file) + +    def test_remove_with_empty_string(self): +        """ +        Passing empty string as data should cause file to be removed + +        Test for: +        https://github.com/tomchristie/django-rest-framework/issues/937 +        """ +        now = datetime.datetime.now() +        file = BytesIO(six.b('stuff')) +        file.name = 'stuff.txt' +        file.size = len(file.getvalue()) + +        uploaded_file = UploadedFile(file=file, created=now) + +        serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.object.created, uploaded_file.created) +        self.assertIsNone(serializer.object.file) + +    def test_validation_error_with_non_file(self): +        """ +        Passing non-files should raise a validation error. +        """ +        now = datetime.datetime.now() +        errmsg = 'No file was submitted. Check the encoding type on the form.' + +        serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'file': [errmsg]}) + +    def test_validation_with_no_data(self): +        """ +        Validation should still function when no data dictionary is provided. +        """ +        now = datetime.datetime.now() +        file = BytesIO(six.b('stuff')) +        file.name = 'stuff.txt' +        file.size = len(file.getvalue()) +        uploaded_file = UploadedFile(file=file, created=now) + +        serializer = UploadedFileSerializer(files={'file': file}) +        self.assertFalse(serializer.is_valid()) diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 00000000..3c6e8857 --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,660 @@ +from __future__ import unicode_literals +import datetime +from decimal import Decimal +from django.db import models +from django.core.urlresolvers import reverse +from django.test import TestCase +from django.utils import unittest +from rest_framework import generics, serializers, status, filters +from rest_framework.compat import django_filters, patterns, url +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from .models import FilterableItem, BasicModel +from .utils import temporary_setting + +factory = APIRequestFactory() + + +if django_filters: +    # Basic filter on a list view. +    class FilterFieldsRootView(generics.ListCreateAPIView): +        model = FilterableItem +        filter_fields = ['decimal', 'date'] +        filter_backends = (filters.DjangoFilterBackend,) + +    # These class are used to test a filter class. +    class SeveralFieldsFilter(django_filters.FilterSet): +        text = django_filters.CharFilter(lookup_type='icontains') +        decimal = django_filters.NumberFilter(lookup_type='lt') +        date = django_filters.DateFilter(lookup_type='gt') + +        class Meta: +            model = FilterableItem +            fields = ['text', 'decimal', 'date'] + +    class FilterClassRootView(generics.ListCreateAPIView): +        model = FilterableItem +        filter_class = SeveralFieldsFilter +        filter_backends = (filters.DjangoFilterBackend,) + +    # These classes are used to test a misconfigured filter class. +    class MisconfiguredFilter(django_filters.FilterSet): +        text = django_filters.CharFilter(lookup_type='icontains') + +        class Meta: +            model = BasicModel +            fields = ['text'] + +    class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): +        model = FilterableItem +        filter_class = MisconfiguredFilter +        filter_backends = (filters.DjangoFilterBackend,) + +    class FilterClassDetailView(generics.RetrieveAPIView): +        model = FilterableItem +        filter_class = SeveralFieldsFilter +        filter_backends = (filters.DjangoFilterBackend,) + +    # Regression test for #814 +    class FilterableItemSerializer(serializers.ModelSerializer): +        class Meta: +            model = FilterableItem + +    class FilterFieldsQuerysetView(generics.ListCreateAPIView): +        queryset = FilterableItem.objects.all() +        serializer_class = FilterableItemSerializer +        filter_fields = ['decimal', 'date'] +        filter_backends = (filters.DjangoFilterBackend,) + +    class GetQuerysetView(generics.ListCreateAPIView): +        serializer_class = FilterableItemSerializer +        filter_class = SeveralFieldsFilter +        filter_backends = (filters.DjangoFilterBackend,) + +        def get_queryset(self): +            return FilterableItem.objects.all() + +    urlpatterns = patterns('', +        url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), +        url(r'^$', FilterClassRootView.as_view(), name='root-view'), +        url(r'^get-queryset/$', GetQuerysetView.as_view(), +            name='get-queryset-view'), +    ) + + +class CommonFilteringTestCase(TestCase): +    def _serialize_object(self, obj): +        return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + +    def setUp(self): +        """ +        Create 10 FilterableItem instances. +        """ +        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) +        for i in range(10): +            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. +            decimal = base_data[1] + i +            date = base_data[2] - datetime.timedelta(days=i * 2) +            FilterableItem(text=text, decimal=decimal, date=date).save() + +        self.objects = FilterableItem.objects +        self.data = [ +            self._serialize_object(obj) +            for obj in self.objects.all() +        ] + + +class IntegrationTestFiltering(CommonFilteringTestCase): +    """ +    Integration tests for filtered list views. +    """ + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_get_filtered_fields_root_view(self): +        """ +        GET requests to paginated ListCreateAPIView should return paginated results. +        """ +        view = FilterFieldsRootView.as_view() + +        # Basic test with no filter. +        request = factory.get('/') +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +        # Tests that the decimal filter works. +        search_decimal = Decimal('2.25') +        request = factory.get('/', {'decimal': '%s' % search_decimal}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if f['decimal'] == search_decimal] +        self.assertEqual(response.data, expected_data) + +        # Tests that the date filter works. +        search_date = datetime.date(2012, 9, 22) +        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-09-22' +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if f['date'] == search_date] +        self.assertEqual(response.data, expected_data) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_filter_with_queryset(self): +        """ +        Regression test for #814. +        """ +        view = FilterFieldsQuerysetView.as_view() + +        # Tests that the decimal filter works. +        search_decimal = Decimal('2.25') +        request = factory.get('/', {'decimal': '%s' % search_decimal}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if f['decimal'] == search_decimal] +        self.assertEqual(response.data, expected_data) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_filter_with_get_queryset_only(self): +        """ +        Regression test for #834. +        """ +        view = GetQuerysetView.as_view() +        request = factory.get('/get-queryset/') +        view(request).render() +        # Used to raise "issubclass() arg 2 must be a class or tuple of classes" +        # here when neither `model' nor `queryset' was specified. + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_get_filtered_class_root_view(self): +        """ +        GET requests to filtered ListCreateAPIView that have a filter_class set +        should return filtered results. +        """ +        view = FilterClassRootView.as_view() + +        # Basic test with no filter. +        request = factory.get('/') +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +        # Tests that the decimal filter set with 'lt' in the filter class works. +        search_decimal = Decimal('4.25') +        request = factory.get('/', {'decimal': '%s' % search_decimal}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if f['decimal'] < search_decimal] +        self.assertEqual(response.data, expected_data) + +        # Tests that the date filter set with 'gt' in the filter class works. +        search_date = datetime.date(2012, 10, 2) +        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-10-02' +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if f['date'] > search_date] +        self.assertEqual(response.data, expected_data) + +        # Tests that the text filter set with 'icontains' in the filter class works. +        search_text = 'ff' +        request = factory.get('/', {'text': '%s' % search_text}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if search_text in f['text'].lower()] +        self.assertEqual(response.data, expected_data) + +        # Tests that multiple filters works. +        search_decimal = Decimal('5.25') +        search_date = datetime.date(2012, 10, 2) +        request = factory.get('/', { +            'decimal': '%s' % (search_decimal,), +            'date': '%s' % (search_date,) +        }) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        expected_data = [f for f in self.data if f['date'] > search_date and +                         f['decimal'] < search_decimal] +        self.assertEqual(response.data, expected_data) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_incorrectly_configured_filter(self): +        """ +        An error should be displayed when the filter class is misconfigured. +        """ +        view = IncorrectlyConfiguredRootView.as_view() + +        request = factory.get('/') +        self.assertRaises(AssertionError, view, request) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_unknown_filter(self): +        """ +        GET requests with filters that aren't configured should return 200. +        """ +        view = FilterFieldsRootView.as_view() + +        search_integer = 10 +        request = factory.get('/', {'integer': '%s' % search_integer}) +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): +    """ +    Integration tests for filtered detail views. +    """ +    urls = 'tests.test_filters' + +    def _get_url(self, item): +        return reverse('detail-view', kwargs=dict(pk=item.pk)) + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_get_filtered_detail_view(self): +        """ +        GET requests to filtered RetrieveAPIView that have a filter_class set +        should return filtered results. +        """ +        item = self.objects.all()[0] +        data = self._serialize_object(item) + +        # Basic test with no filter. +        response = self.client.get(self._get_url(item)) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, data) + +        # Tests that the decimal filter set that should fail. +        search_decimal = Decimal('4.25') +        high_item = self.objects.filter(decimal__gt=search_decimal)[0] +        response = self.client.get( +            '{url}'.format(url=self._get_url(high_item)), +            {'decimal': '{param}'.format(param=search_decimal)}) +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +        # Tests that the decimal filter set that should succeed. +        search_decimal = Decimal('4.25') +        low_item = self.objects.filter(decimal__lt=search_decimal)[0] +        low_item_data = self._serialize_object(low_item) +        response = self.client.get( +            '{url}'.format(url=self._get_url(low_item)), +            {'decimal': '{param}'.format(param=search_decimal)}) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, low_item_data) + +        # Tests that multiple filters works. +        search_decimal = Decimal('5.25') +        search_date = datetime.date(2012, 10, 2) +        valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] +        valid_item_data = self._serialize_object(valid_item) +        response = self.client.get( +            '{url}'.format(url=self._get_url(valid_item)), { +                'decimal': '{decimal}'.format(decimal=search_decimal), +                'date': '{date}'.format(date=search_date) +            }) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): +    title = models.CharField(max_length=20) +    text = models.CharField(max_length=100) + + +class SearchFilterTests(TestCase): +    def setUp(self): +        # Sequence of title/text is: +        # +        # z   abc +        # zz  bcd +        # zzz cde +        # ... +        for idx in range(10): +            title = 'z' * (idx + 1) +            text = ( +                chr(idx + ord('a')) + +                chr(idx + ord('b')) + +                chr(idx + ord('c')) +            ) +            SearchFilterModel(title=title, text=text).save() + +    def test_search(self): +        class SearchListView(generics.ListAPIView): +            model = SearchFilterModel +            filter_backends = (filters.SearchFilter,) +            search_fields = ('title', 'text') + +        view = SearchListView.as_view() +        request = factory.get('/', {'search': 'b'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 1, 'title': 'z', 'text': 'abc'}, +                {'id': 2, 'title': 'zz', 'text': 'bcd'} +            ] +        ) + +    def test_exact_search(self): +        class SearchListView(generics.ListAPIView): +            model = SearchFilterModel +            filter_backends = (filters.SearchFilter,) +            search_fields = ('=title', 'text') + +        view = SearchListView.as_view() +        request = factory.get('/', {'search': 'zzz'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'zzz', 'text': 'cde'} +            ] +        ) + +    def test_startswith_search(self): +        class SearchListView(generics.ListAPIView): +            model = SearchFilterModel +            filter_backends = (filters.SearchFilter,) +            search_fields = ('title', '^text') + +        view = SearchListView.as_view() +        request = factory.get('/', {'search': 'b'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 2, 'title': 'zz', 'text': 'bcd'} +            ] +        ) + +    def test_search_with_nonstandard_search_param(self): +        with temporary_setting('SEARCH_PARAM', 'query', module=filters): +            class SearchListView(generics.ListAPIView): +                model = SearchFilterModel +                filter_backends = (filters.SearchFilter,) +                search_fields = ('title', 'text') + +            view = SearchListView.as_view() +            request = factory.get('/', {'query': 'b'}) +            response = view(request) +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, 'title': 'z', 'text': 'abc'}, +                    {'id': 2, 'title': 'zz', 'text': 'bcd'} +                ] +            ) + + +class OrdringFilterModel(models.Model): +    title = models.CharField(max_length=20) +    text = models.CharField(max_length=100) + + +class OrderingFilterRelatedModel(models.Model): +    related_object = models.ForeignKey(OrdringFilterModel, +                                       related_name="relateds") + + +class OrderingFilterTests(TestCase): +    def setUp(self): +        # Sequence of title/text is: +        # +        # zyx abc +        # yxw bcd +        # xwv cde +        for idx in range(3): +            title = ( +                chr(ord('z') - idx) + +                chr(ord('y') - idx) + +                chr(ord('x') - idx) +            ) +            text = ( +                chr(idx + ord('a')) + +                chr(idx + ord('b')) + +                chr(idx + ord('c')) +            ) +            OrdringFilterModel(title=title, text=text).save() + +    def test_ordering(self): +        class OrderingListView(generics.ListAPIView): +            model = OrdringFilterModel +            filter_backends = (filters.OrderingFilter,) +            ordering = ('title',) +            ordering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('/', {'ordering': 'text'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +            ] +        ) + +    def test_reverse_ordering(self): +        class OrderingListView(generics.ListAPIView): +            model = OrdringFilterModel +            filter_backends = (filters.OrderingFilter,) +            ordering = ('title',) +            ordering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('/', {'ordering': '-text'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +            ] +        ) + +    def test_incorrectfield_ordering(self): +        class OrderingListView(generics.ListAPIView): +            model = OrdringFilterModel +            filter_backends = (filters.OrderingFilter,) +            ordering = ('title',) +            ordering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('/', {'ordering': 'foobar'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +            ] +        ) + +    def test_default_ordering(self): +        class OrderingListView(generics.ListAPIView): +            model = OrdringFilterModel +            filter_backends = (filters.OrderingFilter,) +            ordering = ('title',) +            oredering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('') +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +            ] +        ) + +    def test_default_ordering_using_string(self): +        class OrderingListView(generics.ListAPIView): +            model = OrdringFilterModel +            filter_backends = (filters.OrderingFilter,) +            ordering = 'title' +            ordering_fields = ('text',) + +        view = OrderingListView.as_view() +        request = factory.get('') +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +            ] +        ) + +    def test_ordering_by_aggregate_field(self): +        # create some related models to aggregate order by +        num_objs = [2, 5, 3] +        for obj, num_relateds in zip(OrdringFilterModel.objects.all(), +                                     num_objs): +            for _ in range(num_relateds): +                new_related = OrderingFilterRelatedModel( +                    related_object=obj +                ) +                new_related.save() + +        class OrderingListView(generics.ListAPIView): +            model = OrdringFilterModel +            filter_backends = (filters.OrderingFilter,) +            ordering = 'title' +            ordering_fields = '__all__' +            queryset = OrdringFilterModel.objects.all().annotate( +                models.Count("relateds")) + +        view = OrderingListView.as_view() +        request = factory.get('/', {'ordering': 'relateds__count'}) +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +            ] +        ) + +    def test_ordering_with_nonstandard_ordering_param(self): +        with temporary_setting('ORDERING_PARAM', 'order', filters): +            class OrderingListView(generics.ListAPIView): +                model = OrdringFilterModel +                filter_backends = (filters.OrderingFilter,) +                ordering = ('title',) +                ordering_fields = ('text',) + +            view = OrderingListView.as_view() +            request = factory.get('/', {'order': 'text'}) +            response = view(request) +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                    {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                    {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                ] +            ) + + +class SensitiveOrderingFilterModel(models.Model): +    username = models.CharField(max_length=20) +    password = models.CharField(max_length=100) + + +# Three different styles of serializer. +# All should allow ordering by username, but not by password. +class SensitiveDataSerializer1(serializers.ModelSerializer): +    username = serializers.CharField() + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'username') + + +class SensitiveDataSerializer2(serializers.ModelSerializer): +    username = serializers.CharField() +    password = serializers.CharField(write_only=True) + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'username', 'password') + + +class SensitiveDataSerializer3(serializers.ModelSerializer): +    user = serializers.CharField(source='username') + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'user') + + +class SensitiveOrderingFilterTests(TestCase): +    def setUp(self): +        for idx in range(3): +            username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] +            password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] +            SensitiveOrderingFilterModel(username=username, password=password).save() + +    def test_order_by_serializer_fields(self): +        for serializer_cls in [ +            SensitiveDataSerializer1, +            SensitiveDataSerializer2, +            SensitiveDataSerializer3 +        ]: +            class OrderingListView(generics.ListAPIView): +                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') +                filter_backends = (filters.OrderingFilter,) +                serializer_class = serializer_cls + +            view = OrderingListView.as_view() +            request = factory.get('/', {'ordering': '-username'}) +            response = view(request) + +            if serializer_cls == SensitiveDataSerializer3: +                username_field = 'user' +            else: +                username_field = 'username' + +            # Note: Inverse username ordering correctly applied. +            self.assertEqual( +                response.data, +                [ +                    {'id': 3, username_field: 'userC'}, +                    {'id': 2, username_field: 'userB'}, +                    {'id': 1, username_field: 'userA'}, +                ] +            ) + +    def test_cannot_order_by_non_serializer_fields(self): +        for serializer_cls in [ +            SensitiveDataSerializer1, +            SensitiveDataSerializer2, +            SensitiveDataSerializer3 +        ]: +            class OrderingListView(generics.ListAPIView): +                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') +                filter_backends = (filters.OrderingFilter,) +                serializer_class = serializer_cls + +            view = OrderingListView.as_view() +            request = factory.get('/', {'ordering': 'password'}) +            response = view(request) + +            if serializer_cls == SensitiveDataSerializer3: +                username_field = 'user' +            else: +                username_field = 'username' + +            # Note: The passwords are not in order.  Default ordering is used. +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, username_field: 'userA'}, # PassB +                    {'id': 2, username_field: 'userB'}, # PassC +                    {'id': 3, username_field: 'userC'}, # PassA +                ] +            ) diff --git a/tests/test_genericrelations.py b/tests/test_genericrelations.py new file mode 100644 index 00000000..46a2d863 --- /dev/null +++ b/tests/test_genericrelations.py @@ -0,0 +1,151 @@ +from __future__ import unicode_literals +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from rest_framework.compat import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Tag(models.Model): +    """ +    Tags have a descriptive slug, and are attached to an arbitrary object. +    """ +    tag = models.SlugField() +    content_type = models.ForeignKey(ContentType) +    object_id = models.PositiveIntegerField() +    tagged_item = GenericForeignKey('content_type', 'object_id') + +    def __str__(self): +        return self.tag + + +@python_2_unicode_compatible +class Bookmark(models.Model): +    """ +    A URL bookmark that may have multiple tags attached. +    """ +    url = models.URLField() +    tags = GenericRelation(Tag) + +    def __str__(self): +        return 'Bookmark: %s' % self.url + + +@python_2_unicode_compatible +class Note(models.Model): +    """ +    A textual note that may have multiple tags attached. +    """ +    text = models.TextField() +    tags = GenericRelation(Tag) + +    def __str__(self): +        return 'Note: %s' % self.text + + +class TestGenericRelations(TestCase): +    def setUp(self): +        self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') +        Tag.objects.create(tagged_item=self.bookmark, tag='django') +        Tag.objects.create(tagged_item=self.bookmark, tag='python') +        self.note = Note.objects.create(text='Remember the milk') +        Tag.objects.create(tagged_item=self.note, tag='reminder') + +    def test_generic_relation(self): +        """ +        Test a relationship that spans a GenericRelation field. +        IE. A reverse generic relationship. +        """ + +        class BookmarkSerializer(serializers.ModelSerializer): +            tags = serializers.RelatedField(many=True) + +            class Meta: +                model = Bookmark +                exclude = ('id',) + +        serializer = BookmarkSerializer(self.bookmark) +        expected = { +            'tags': ['django', 'python'], +            'url': 'https://www.djangoproject.com/' +        } +        self.assertEqual(serializer.data, expected) + +    def test_generic_nested_relation(self): +        """ +        Test saving a GenericRelation field via a nested serializer. +        """ + +        class TagSerializer(serializers.ModelSerializer): +            class Meta: +                model = Tag +                exclude = ('content_type', 'object_id') + +        class BookmarkSerializer(serializers.ModelSerializer): +            tags = TagSerializer() + +            class Meta: +                model = Bookmark +                exclude = ('id',) + +        data = { +            'url': 'https://docs.djangoproject.com/', +            'tags': [ +                {'tag': 'contenttypes'}, +                {'tag': 'genericrelations'}, +            ] +        } +        serializer = BookmarkSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.object.tags.count(), 2) + +    def test_generic_fk(self): +        """ +        Test a relationship that spans a GenericForeignKey field. +        IE. A forward generic relationship. +        """ + +        class TagSerializer(serializers.ModelSerializer): +            tagged_item = serializers.RelatedField() + +            class Meta: +                model = Tag +                exclude = ('id', 'content_type', 'object_id') + +        serializer = TagSerializer(Tag.objects.all(), many=True) +        expected = [ +        { +            'tag': 'django', +            'tagged_item': 'Bookmark: https://www.djangoproject.com/' +        }, +        { +            'tag': 'python', +            'tagged_item': 'Bookmark: https://www.djangoproject.com/' +        }, +        { +            'tag': 'reminder', +            'tagged_item': 'Note: Remember the milk' +        } +        ] +        self.assertEqual(serializer.data, expected) + +    def test_restore_object_generic_fk(self): +        """ +        Ensure an object with a generic foreign key can be restored. +        """ + +        class TagSerializer(serializers.ModelSerializer): +            class Meta: +                model = Tag +                exclude = ('content_type', 'object_id') + +        serializer = TagSerializer() + +        bookmark = Bookmark(url='http://example.com') +        attrs = {'tagged_item': bookmark, 'tag': 'example'} + +        tag = serializer.restore_object(attrs) +        self.assertEqual(tag.tagged_item, bookmark) diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 00000000..4389994a --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,609 @@ +from __future__ import unicode_literals +from django.db import models +from django.shortcuts import get_object_or_404 +from django.test import TestCase +from rest_framework import generics, renderers, serializers, status +from rest_framework.test import APIRequestFactory +from tests.models import BasicModel, Comment, SlugBasedModel +from rest_framework.compat import six + +factory = APIRequestFactory() + + +class RootView(generics.ListCreateAPIView): +    """ +    Example description for OPTIONS. +    """ +    model = BasicModel + + +class InstanceView(generics.RetrieveUpdateDestroyAPIView): +    """ +    Example description for OPTIONS. +    """ +    model = BasicModel + +    def get_queryset(self): +        queryset = super(InstanceView, self).get_queryset() +        return queryset.exclude(text='filtered out') + + +class SlugSerializer(serializers.ModelSerializer): +    slug = serializers.Field()  # read only + +    class Meta: +        model = SlugBasedModel +        exclude = ('id',) + + +class SlugBasedInstanceView(InstanceView): +    """ +    A model with a slug-field. +    """ +    model = SlugBasedModel +    serializer_class = SlugSerializer +    lookup_field = 'slug' + + +class TestRootView(TestCase): +    def setUp(self): +        """ +        Create 3 BasicModel instances. +        """ +        items = ['foo', 'bar', 'baz'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = RootView.as_view() + +    def test_get_root_view(self): +        """ +        GET requests to ListCreateAPIView should return list of objects. +        """ +        request = factory.get('/') +        with self.assertNumQueries(1): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +    def test_post_root_view(self): +        """ +        POST requests to ListCreateAPIView should create a new object. +        """ +        data = {'text': 'foobar'} +        request = factory.post('/', data, format='json') +        with self.assertNumQueries(1): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) +        created = self.objects.get(id=4) +        self.assertEqual(created.text, 'foobar') + +    def test_put_root_view(self): +        """ +        PUT requests to ListCreateAPIView should not be allowed +        """ +        data = {'text': 'foobar'} +        request = factory.put('/', data, format='json') +        with self.assertNumQueries(0): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) +        self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) + +    def test_delete_root_view(self): +        """ +        DELETE requests to ListCreateAPIView should not be allowed +        """ +        request = factory.delete('/') +        with self.assertNumQueries(0): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) +        self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) + +    def test_options_root_view(self): +        """ +        OPTIONS requests to ListCreateAPIView should return metadata +        """ +        request = factory.options('/') +        with self.assertNumQueries(0): +            response = self.view(request).render() +        expected = { +            'parses': [ +                'application/json', +                'application/x-www-form-urlencoded', +                'multipart/form-data' +            ], +            'renders': [ +                'application/json', +                'text/html' +            ], +            'name': 'Root', +            'description': 'Example description for OPTIONS.', +            'actions': { +                'POST': { +                    'text': { +                        'max_length': 100, +                        'read_only': False, +                        'required': True, +                        'type': 'string', +                        "label": "Text comes here", +                        "help_text": "Text description." +                    }, +                    'id': { +                        'read_only': True, +                        'required': False, +                        'type': 'integer', +                        'label': 'ID', +                    }, +                } +            } +        } +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, expected) + +    def test_post_cannot_set_id(self): +        """ +        POST requests to create a new object should not be able to set the id. +        """ +        data = {'id': 999, 'text': 'foobar'} +        request = factory.post('/', data, format='json') +        with self.assertNumQueries(1): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        self.assertEqual(response.data, {'id': 4, 'text': 'foobar'}) +        created = self.objects.get(id=4) +        self.assertEqual(created.text, 'foobar') + + +class TestInstanceView(TestCase): +    def setUp(self): +        """ +        Create 3 BasicModel intances. +        """ +        items = ['foo', 'bar', 'baz', 'filtered out'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects.exclude(text='filtered out') +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = InstanceView.as_view() +        self.slug_based_view = SlugBasedInstanceView.as_view() + +    def test_get_instance_view(self): +        """ +        GET requests to RetrieveUpdateDestroyAPIView should return a single object. +        """ +        request = factory.get('/1') +        with self.assertNumQueries(1): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data[0]) + +    def test_post_instance_view(self): +        """ +        POST requests to RetrieveUpdateDestroyAPIView should not be allowed +        """ +        data = {'text': 'foobar'} +        request = factory.post('/', data, format='json') +        with self.assertNumQueries(0): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) +        self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) + +    def test_put_instance_view(self): +        """ +        PUT requests to RetrieveUpdateDestroyAPIView should update an object. +        """ +        data = {'text': 'foobar'} +        request = factory.put('/1', data, format='json') +        with self.assertNumQueries(2): +            response = self.view(request, pk='1').render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEqual(updated.text, 'foobar') + +    def test_patch_instance_view(self): +        """ +        PATCH requests to RetrieveUpdateDestroyAPIView should update an object. +        """ +        data = {'text': 'foobar'} +        request = factory.patch('/1', data, format='json') + +        with self.assertNumQueries(2): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEqual(updated.text, 'foobar') + +    def test_delete_instance_view(self): +        """ +        DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. +        """ +        request = factory.delete('/1') +        with self.assertNumQueries(2): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) +        self.assertEqual(response.content, six.b('')) +        ids = [obj.id for obj in self.objects.all()] +        self.assertEqual(ids, [2, 3]) + +    def test_options_instance_view(self): +        """ +        OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata +        """ +        request = factory.options('/1') +        with self.assertNumQueries(1): +            response = self.view(request, pk=1).render() +        expected = { +            'parses': [ +                'application/json', +                'application/x-www-form-urlencoded', +                'multipart/form-data' +            ], +            'renders': [ +                'application/json', +                'text/html' +            ], +            'name': 'Instance', +            'description': 'Example description for OPTIONS.', +            'actions': { +                'PUT': { +                    'text': { +                        'max_length': 100, +                        'read_only': False, +                        'required': True, +                        'type': 'string', +                        'label': 'Text comes here', +                        'help_text': 'Text description.' +                    }, +                    'id': { +                        'read_only': True, +                        'required': False, +                        'type': 'integer', +                        'label': 'ID', +                    }, +                } +            } +        } +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, expected) + +    def test_options_before_instance_create(self): +        """ +        OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata +        before the instance has been created +        """ +        request = factory.options('/999') +        with self.assertNumQueries(1): +            response = self.view(request, pk=999).render() +        expected = { +            'parses': [ +                'application/json', +                'application/x-www-form-urlencoded', +                'multipart/form-data' +            ], +            'renders': [ +                'application/json', +                'text/html' +            ], +            'name': 'Instance', +            'description': 'Example description for OPTIONS.', +            'actions': { +                'PUT': { +                    'text': { +                        'max_length': 100, +                        'read_only': False, +                        'required': True, +                        'type': 'string', +                        'label': 'Text comes here', +                        'help_text': 'Text description.' +                    }, +                    'id': { +                        'read_only': True, +                        'required': False, +                        'type': 'integer', +                        'label': 'ID', +                    }, +                } +            } +        } +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, expected) + +    def test_get_instance_view_incorrect_arg(self): +        """ +        GET requests with an incorrect pk type, should raise 404, not 500. +        Regression test for #890. +        """ +        request = factory.get('/a') +        with self.assertNumQueries(0): +            response = self.view(request, pk='a').render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    def test_put_cannot_set_id(self): +        """ +        PUT requests to create a new object should not be able to set the id. +        """ +        data = {'id': 999, 'text': 'foobar'} +        request = factory.put('/1', data, format='json') +        with self.assertNumQueries(2): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEqual(updated.text, 'foobar') + +    def test_put_to_deleted_instance(self): +        """ +        PUT requests to RetrieveUpdateDestroyAPIView should create an object +        if it does not currently exist. +        """ +        self.objects.get(id=1).delete() +        data = {'text': 'foobar'} +        request = factory.put('/1', data, format='json') +        with self.assertNumQueries(3): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEqual(updated.text, 'foobar') + +    def test_put_to_filtered_out_instance(self): +        """ +        PUT requests to an URL of instance which is filtered out should not be +        able to create new objects. +        """ +        data = {'text': 'foo'} +        filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk +        request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') +        response = self.view(request, pk=filtered_out_pk).render() +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + +    def test_put_as_create_on_id_based_url(self): +        """ +        PUT requests to RetrieveUpdateDestroyAPIView should create an object +        at the requested url if it doesn't exist. +        """ +        data = {'text': 'foobar'} +        # pk fields can not be created on demand, only the database can set the pk for a new object +        request = factory.put('/5', data, format='json') +        with self.assertNumQueries(3): +            response = self.view(request, pk=5).render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        new_obj = self.objects.get(pk=5) +        self.assertEqual(new_obj.text, 'foobar') + +    def test_put_as_create_on_slug_based_url(self): +        """ +        PUT requests to RetrieveUpdateDestroyAPIView should create an object +        at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. +        """ +        data = {'text': 'foobar'} +        request = factory.put('/test_slug', data, format='json') +        with self.assertNumQueries(2): +            response = self.slug_based_view(request, slug='test_slug').render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'}) +        new_obj = SlugBasedModel.objects.get(slug='test_slug') +        self.assertEqual(new_obj.text, 'foobar') + +    def test_patch_cannot_create_an_object(self): +        """ +        PATCH requests should not be able to create objects. +        """ +        data = {'text': 'foobar'} +        request = factory.patch('/999', data, format='json') +        with self.assertNumQueries(1): +            response = self.view(request, pk=999).render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertFalse(self.objects.filter(id=999).exists()) + + +class TestOverriddenGetObject(TestCase): +    """ +    Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the +    queryset/model mechanism but instead overrides get_object() +    """ +    def setUp(self): +        """ +        Create 3 BasicModel intances. +        """ +        items = ['foo', 'bar', 'baz'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] + +        class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): +            """ +            Example detail view for override of get_object(). +            """ +            model = BasicModel + +            def get_object(self): +                pk = int(self.kwargs['pk']) +                return get_object_or_404(BasicModel.objects.all(), id=pk) + +        self.view = OverriddenGetObjectView.as_view() + +    def test_overridden_get_object_view(self): +        """ +        GET requests to RetrieveUpdateDestroyAPIView should return a single object. +        """ +        request = factory.get('/1') +        with self.assertNumQueries(1): +            response = self.view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data[0]) + + +# Regression test for #285 + +class CommentSerializer(serializers.ModelSerializer): +    class Meta: +        model = Comment +        exclude = ('created',) + + +class CommentView(generics.ListCreateAPIView): +    serializer_class = CommentSerializer +    model = Comment + + +class TestCreateModelWithAutoNowAddField(TestCase): +    def setUp(self): +        self.objects = Comment.objects +        self.view = CommentView.as_view() + +    def test_create_model_with_auto_now_add_field(self): +        """ +        Regression test for #285 + +        https://github.com/tomchristie/django-rest-framework/issues/285 +        """ +        data = {'email': 'foobar@example.com', 'content': 'foobar'} +        request = factory.post('/', data, format='json') +        response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        created = self.objects.get(id=1) +        self.assertEqual(created.content, 'foobar') + + +# Test for particularly ugly regression with m2m in browsable API +class ClassB(models.Model): +    name = models.CharField(max_length=255) + + +class ClassA(models.Model): +    name = models.CharField(max_length=255) +    childs = models.ManyToManyField(ClassB, blank=True, null=True) + + +class ClassASerializer(serializers.ModelSerializer): +    childs = serializers.PrimaryKeyRelatedField(many=True, source='childs') + +    class Meta: +        model = ClassA + + +class ExampleView(generics.ListCreateAPIView): +    serializer_class = ClassASerializer +    model = ClassA + + +class TestM2MBrowseableAPI(TestCase): +    def test_m2m_in_browseable_api(self): +        """ +        Test for particularly ugly regression with m2m in browsable API +        """ +        request = factory.get('/', HTTP_ACCEPT='text/html') +        view = ExampleView().as_view() +        response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class InclusiveFilterBackend(object): +    def filter_queryset(self, request, queryset, view): +        return queryset.filter(text='foo') + + +class ExclusiveFilterBackend(object): +    def filter_queryset(self, request, queryset, view): +        return queryset.filter(text='other') + + +class TwoFieldModel(models.Model): +    field_a = models.CharField(max_length=100) +    field_b = models.CharField(max_length=100) + + +class DynamicSerializerView(generics.ListCreateAPIView): +    model = TwoFieldModel +    renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + +    def get_serializer_class(self): +        if self.request.method == 'POST': +            class DynamicSerializer(serializers.ModelSerializer): +                class Meta: +                    model = TwoFieldModel +                    fields = ('field_b',) +            return DynamicSerializer +        return super(DynamicSerializerView, self).get_serializer_class() + + +class TestFilterBackendAppliedToViews(TestCase): + +    def setUp(self): +        """ +        Create 3 BasicModel instances to filter on. +        """ +        items = ['foo', 'bar', 'baz'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] + +    def test_get_root_view_filters_by_name_with_filter_backend(self): +        """ +        GET requests to ListCreateAPIView should return filtered list. +        """ +        root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) +        request = factory.get('/') +        response = root_view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(len(response.data), 1) +        self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) + +    def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): +        """ +        GET requests to ListCreateAPIView should return empty list when all models are filtered out. +        """ +        root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) +        request = factory.get('/') +        response = root_view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, []) + +    def test_get_instance_view_filters_out_name_with_filter_backend(self): +        """ +        GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. +        """ +        instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) +        request = factory.get('/1') +        response = instance_view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertEqual(response.data, {'detail': 'Not found'}) + +    def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): +        """ +        GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded +        """ +        instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) +        request = factory.get('/1') +        response = instance_view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) + +    def test_dynamic_serializer_form_in_browsable_api(self): +        """ +        GET requests to ListCreateAPIView should return filtered list. +        """ +        view = DynamicSerializerView.as_view() +        request = factory.get('/') +        response = view(request).render() +        self.assertContains(response, 'field_b') +        self.assertNotContains(response, 'field_a') diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py new file mode 100644 index 00000000..8af5bb50 --- /dev/null +++ b/tests/test_htmlrenderer.py @@ -0,0 +1,120 @@ +from __future__ import unicode_literals +from django.core.exceptions import PermissionDenied +from django.http import Http404 +from django.test import TestCase +from django.template import TemplateDoesNotExist, Template +import django.template.loader +from rest_framework import status +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view, renderer_classes +from rest_framework.renderers import TemplateHTMLRenderer +from rest_framework.response import Response +from rest_framework.compat import six + + +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def example(request): +    """ +    A view that can returns an HTML representation. +    """ +    data = {'object': 'foobar'} +    return Response(data, template_name='example.html') + + +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def permission_denied(request): +    raise PermissionDenied() + + +@api_view(('GET',)) +@renderer_classes((TemplateHTMLRenderer,)) +def not_found(request): +    raise Http404() + + +urlpatterns = patterns('', +    url(r'^$', example), +    url(r'^permission_denied$', permission_denied), +    url(r'^not_found$', not_found), +) + + +class TemplateHTMLRendererTests(TestCase): +    urls = 'tests.test_htmlrenderer' + +    def setUp(self): +        """ +        Monkeypatch get_template +        """ +        self.get_template = django.template.loader.get_template + +        def get_template(template_name, dirs=None): +            if template_name == 'example.html': +                return Template("example: {{ object }}") +            raise TemplateDoesNotExist(template_name) + +        django.template.loader.get_template = get_template + +    def tearDown(self): +        """ +        Revert monkeypatching +        """ +        django.template.loader.get_template = self.get_template + +    def test_simple_html_view(self): +        response = self.client.get('/') +        self.assertContains(response, "example: foobar") +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + +    def test_not_found_html_view(self): +        response = self.client.get('/not_found') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertEqual(response.content, six.b("404 Not Found")) +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + +    def test_permission_denied_html_view(self): +        response = self.client.get('/permission_denied') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) +        self.assertEqual(response.content, six.b("403 Forbidden")) +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + + +class TemplateHTMLRendererExceptionTests(TestCase): +    urls = 'tests.test_htmlrenderer' + +    def setUp(self): +        """ +        Monkeypatch get_template +        """ +        self.get_template = django.template.loader.get_template + +        def get_template(template_name): +            if template_name == '404.html': +                return Template("404: {{ detail }}") +            if template_name == '403.html': +                return Template("403: {{ detail }}") +            raise TemplateDoesNotExist(template_name) + +        django.template.loader.get_template = get_template + +    def tearDown(self): +        """ +        Revert monkeypatching +        """ +        django.template.loader.get_template = self.get_template + +    def test_not_found_html_view_with_template(self): +        response = self.client.get('/not_found') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertTrue(response.content in ( +            six.b("404: Not found"), six.b("404 Not Found"))) +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') + +    def test_permission_denied_html_view_with_template(self): +        response = self.client.get('/permission_denied') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) +        self.assertTrue(response.content in ( +            six.b("403: Permission denied"), six.b("403 Forbidden"))) +        self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/tests/test_hyperlinkedserializers.py b/tests/test_hyperlinkedserializers.py new file mode 100644 index 00000000..eee179ca --- /dev/null +++ b/tests/test_hyperlinkedserializers.py @@ -0,0 +1,379 @@ +from __future__ import unicode_literals +import json +from django.test import TestCase +from rest_framework import generics, status, serializers +from rest_framework.compat import patterns, url +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from tests.models import ( +    Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, +    Album, Photo, OptionalRelationModel +) + +factory = APIRequestFactory() + + +class BlogPostCommentSerializer(serializers.ModelSerializer): +    url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') +    text = serializers.CharField() +    blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') + +    class Meta: +        model = BlogPostComment +        fields = ('text', 'blog_post_url', 'url') + + +class PhotoSerializer(serializers.Serializer): +    description = serializers.CharField() +    album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title') + +    def restore_object(self, attrs, instance=None): +        return Photo(**attrs) + + +class AlbumSerializer(serializers.ModelSerializer): +    url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') + +    class Meta: +        model = Album +        fields = ('title', 'url') + + +class BasicList(generics.ListCreateAPIView): +    model = BasicModel +    model_serializer_class = serializers.HyperlinkedModelSerializer + + +class BasicDetail(generics.RetrieveUpdateDestroyAPIView): +    model = BasicModel +    model_serializer_class = serializers.HyperlinkedModelSerializer + + +class AnchorDetail(generics.RetrieveAPIView): +    model = Anchor +    model_serializer_class = serializers.HyperlinkedModelSerializer + + +class ManyToManyList(generics.ListAPIView): +    model = ManyToManyModel +    model_serializer_class = serializers.HyperlinkedModelSerializer + + +class ManyToManyDetail(generics.RetrieveAPIView): +    model = ManyToManyModel +    model_serializer_class = serializers.HyperlinkedModelSerializer + + +class BlogPostCommentListCreate(generics.ListCreateAPIView): +    model = BlogPostComment +    serializer_class = BlogPostCommentSerializer + + +class BlogPostCommentDetail(generics.RetrieveAPIView): +    model = BlogPostComment +    serializer_class = BlogPostCommentSerializer + + +class BlogPostDetail(generics.RetrieveAPIView): +    model = BlogPost + + +class PhotoListCreate(generics.ListCreateAPIView): +    model = Photo +    model_serializer_class = PhotoSerializer + + +class AlbumDetail(generics.RetrieveAPIView): +    model = Album +    serializer_class = AlbumSerializer +    lookup_field = 'title' + + +class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): +    model = OptionalRelationModel +    model_serializer_class = serializers.HyperlinkedModelSerializer + + +urlpatterns = patterns('', +    url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), +    url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), +    url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), +    url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), +    url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), +    url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), +    url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), +    url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), +    url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), +    url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), +    url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), +) + + +class TestBasicHyperlinkedView(TestCase): +    urls = 'tests.test_hyperlinkedserializers' + +    def setUp(self): +        """ +        Create 3 BasicModel instances. +        """ +        items = ['foo', 'bar', 'baz'] +        for item in items: +            BasicModel(text=item).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.list_view = BasicList.as_view() +        self.detail_view = BasicDetail.as_view() + +    def test_get_list_view(self): +        """ +        GET requests to ListCreateAPIView should return list of objects. +        """ +        request = factory.get('/basic/') +        response = self.list_view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +    def test_get_detail_view(self): +        """ +        GET requests to ListCreateAPIView should return list of objects. +        """ +        request = factory.get('/basic/1') +        response = self.detail_view(request, pk=1).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data[0]) + + +class TestManyToManyHyperlinkedView(TestCase): +    urls = 'tests.test_hyperlinkedserializers' + +    def setUp(self): +        """ +        Create 3 BasicModel instances. +        """ +        items = ['foo', 'bar', 'baz'] +        anchors = [] +        for item in items: +            anchor = Anchor(text=item) +            anchor.save() +            anchors.append(anchor) + +        manytomany = ManyToManyModel() +        manytomany.save() +        manytomany.rel.add(*anchors) + +        self.data = [{ +            'url': 'http://testserver/manytomany/1/', +            'rel': [ +                'http://testserver/anchor/1/', +                'http://testserver/anchor/2/', +                'http://testserver/anchor/3/', +            ] +        }] +        self.list_view = ManyToManyList.as_view() +        self.detail_view = ManyToManyDetail.as_view() + +    def test_get_list_view(self): +        """ +        GET requests to ListCreateAPIView should return list of objects. +        """ +        request = factory.get('/manytomany/') +        response = self.list_view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +    def test_get_detail_view(self): +        """ +        GET requests to ListCreateAPIView should return list of objects. +        """ +        request = factory.get('/manytomany/1/') +        response = self.detail_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data[0]) + + +class TestHyperlinkedIdentityFieldLookup(TestCase): +    urls = 'tests.test_hyperlinkedserializers' + +    def setUp(self): +        """ +        Create 3 Album instances. +        """ +        titles = ['foo', 'bar', 'baz'] +        for title in titles: +            album = Album(title=title) +            album.save() +        self.detail_view = AlbumDetail.as_view() +        self.data = { +            'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, +            'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, +            'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} +        } + +    def test_lookup_field(self): +        """ +        GET requests to AlbumDetail view should return serialized Albums +        with a url field keyed by `title`. +        """ +        for album in Album.objects.all(): +            request = factory.get('/albums/{0}/'.format(album.title)) +            response = self.detail_view(request, title=album.title) +            self.assertEqual(response.status_code, status.HTTP_200_OK) +            self.assertEqual(response.data, self.data[album.title]) + + +class TestCreateWithForeignKeys(TestCase): +    urls = 'tests.test_hyperlinkedserializers' + +    def setUp(self): +        """ +        Create a blog post +        """ +        self.post = BlogPost.objects.create(title="Test post") +        self.create_view = BlogPostCommentListCreate.as_view() + +    def test_create_comment(self): + +        data = { +            'text': 'A test comment', +            'blog_post_url': 'http://testserver/posts/1/' +        } + +        request = factory.post('/comments/', data=data) +        response = self.create_view(request) +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        self.assertEqual(response['Location'], 'http://testserver/comments/1/') +        self.assertEqual(self.post.blogpostcomment_set.count(), 1) +        self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') + + +class TestCreateWithForeignKeysAndCustomSlug(TestCase): +    urls = 'tests.test_hyperlinkedserializers' + +    def setUp(self): +        """ +        Create an Album +        """ +        self.post = Album.objects.create(title='test-album') +        self.list_create_view = PhotoListCreate.as_view() + +    def test_create_photo(self): + +        data = { +            'description': 'A test photo', +            'album_url': 'http://testserver/albums/test-album/' +        } + +        request = factory.post('/photos/', data=data) +        response = self.list_create_view(request) +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) +        self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') +        self.assertEqual(self.post.photo_set.count(), 1) +        self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') + + +class TestOptionalRelationHyperlinkedView(TestCase): +    urls = 'tests.test_hyperlinkedserializers' + +    def setUp(self): +        """ +        Create 1 OptionalRelationModel instances. +        """ +        OptionalRelationModel().save() +        self.objects = OptionalRelationModel.objects +        self.detail_view = OptionalRelationDetail.as_view() +        self.data = {"url": "http://testserver/optionalrelation/1/", "other": None} + +    def test_get_detail_view(self): +        """ +        GET requests to RetrieveAPIView with optional relations should return None +        for non existing relations. +        """ +        request = factory.get('/optionalrelationmodel-detail/1') +        response = self.detail_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, self.data) + +    def test_put_detail_view(self): +        """ +        PUT requests to RetrieveUpdateDestroyAPIView with optional relations +        should accept None for non existing relations. +        """ +        response = self.client.put('/optionalrelation/1/', +                                   data=json.dumps(self.data), +                                   content_type='application/json') +        self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class TestOverriddenURLField(TestCase): +    def setUp(self): +        class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer): +            url = serializers.SerializerMethodField('get_url') + +            class Meta: +                model = BlogPost +                fields = ('title', 'url') + +            def get_url(self, obj): +                return 'foo bar' + +        self.Serializer = OverriddenURLSerializer +        self.obj = BlogPost.objects.create(title='New blog post') + +    def test_overridden_url_field(self): +        """ +        The 'url' field should respect overriding. +        Regression test for #936. +        """ +        serializer = self.Serializer(self.obj) +        self.assertEqual( +            serializer.data, +            {'title': 'New blog post', 'url': 'foo bar'} +        ) + + +class TestURLFieldNameBySettings(TestCase): +    urls = 'tests.test_hyperlinkedserializers' + +    def setUp(self): +        self.saved_url_field_name = api_settings.URL_FIELD_NAME +        api_settings.URL_FIELD_NAME = 'global_url_field' + +        class Serializer(serializers.HyperlinkedModelSerializer): + +            class Meta: +                model = BlogPost +                fields = ('title', api_settings.URL_FIELD_NAME) + +        self.Serializer = Serializer +        self.obj = BlogPost.objects.create(title="New blog post") + +    def tearDown(self): +        api_settings.URL_FIELD_NAME = self.saved_url_field_name + +    def test_overridden_url_field_name(self): +        request = factory.get('/posts/') +        serializer = self.Serializer(self.obj, context={'request': request}) +        self.assertIn(api_settings.URL_FIELD_NAME, serializer.data) + + +class TestURLFieldNameByOptions(TestCase): +    urls = 'tests.test_hyperlinkedserializers' + +    def setUp(self): +        class Serializer(serializers.HyperlinkedModelSerializer): + +            class Meta: +                model = BlogPost +                fields = ('title', 'serializer_url_field') +                url_field_name = 'serializer_url_field' + +        self.Serializer = Serializer +        self.obj = BlogPost.objects.create(title="New blog post") + +    def test_overridden_url_field_name(self): +        request = factory.get('/posts/') +        serializer = self.Serializer(self.obj, context={'request': request}) +        self.assertIn(self.Serializer.Meta.url_field_name, serializer.data) diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py new file mode 100644 index 00000000..ce1bf3ea --- /dev/null +++ b/tests/test_multitable_inheritance.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from tests.models import RESTFrameworkModel + + +# Models +class ParentModel(RESTFrameworkModel): +    name1 = models.CharField(max_length=100) + + +class ChildModel(ParentModel): +    name2 = models.CharField(max_length=100) + + +class AssociatedModel(RESTFrameworkModel): +    ref = models.OneToOneField(ParentModel, primary_key=True) +    name = models.CharField(max_length=100) + + +# Serializers +class DerivedModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ChildModel + + +class AssociatedModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = AssociatedModel + + +# Tests +class IneritedModelSerializationTests(TestCase): + +    def test_multitable_inherited_model_fields_as_expected(self): +        """ +        Assert that the parent pointer field is not included in the fields +        serialized fields +        """ +        child = ChildModel(name1='parent name', name2='child name') +        serializer = DerivedModelSerializer(child) +        self.assertEqual(set(serializer.data.keys()), +                         set(['name1', 'name2', 'id'])) + +    def test_onetoone_primary_key_model_fields_as_expected(self): +        """ +        Assert that a model with a onetoone field that is the primary key is +        not treated like a derived model +        """ +        parent = ParentModel(name1='parent name') +        associate = AssociatedModel(name='hello', ref=parent) +        serializer = AssociatedModelSerializer(associate) +        self.assertEqual(set(serializer.data.keys()), +                         set(['name', 'ref'])) + +    def test_data_is_valid_without_parent_ptr(self): +        """ +        Assert that the pointer to the parent table is not a required field +        for input data +        """ +        data = { +            'name1': 'parent name', +            'name2': 'child name', +        } +        serializer = DerivedModelSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py new file mode 100644 index 00000000..04b89eb6 --- /dev/null +++ b/tests/test_negotiation.py @@ -0,0 +1,45 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.negotiation import DefaultContentNegotiation +from rest_framework.request import Request +from rest_framework.renderers import BaseRenderer +from rest_framework.test import APIRequestFactory + + +factory = APIRequestFactory() + + +class MockJSONRenderer(BaseRenderer): +    media_type = 'application/json' + + +class MockHTMLRenderer(BaseRenderer): +    media_type = 'text/html' + + +class NoCharsetSpecifiedRenderer(BaseRenderer): +    media_type = 'my/media' + + +class TestAcceptedMediaType(TestCase): +    def setUp(self): +        self.renderers = [MockJSONRenderer(), MockHTMLRenderer()] +        self.negotiator = DefaultContentNegotiation() + +    def select_renderer(self, request): +        return self.negotiator.select_renderer(request, self.renderers) + +    def test_client_without_accept_use_renderer(self): +        request = Request(factory.get('/')) +        accepted_renderer, accepted_media_type = self.select_renderer(request) +        self.assertEqual(accepted_media_type, 'application/json') + +    def test_client_underspecifies_accept_use_renderer(self): +        request = Request(factory.get('/', HTTP_ACCEPT='*/*')) +        accepted_renderer, accepted_media_type = self.select_renderer(request) +        self.assertEqual(accepted_media_type, 'application/json') + +    def test_client_overspecifies_accept_use_client(self): +        request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) +        accepted_renderer, accepted_media_type = self.select_renderer(request) +        self.assertEqual(accepted_media_type, 'application/json; indent=8') diff --git a/tests/test_nullable_fields.py b/tests/test_nullable_fields.py new file mode 100644 index 00000000..33a9685f --- /dev/null +++ b/tests/test_nullable_fields.py @@ -0,0 +1,30 @@ +from django.core.urlresolvers import reverse + +from rest_framework.compat import patterns, url +from rest_framework.test import APITestCase +from tests.models import NullableForeignKeySource +from tests.serializers import NullableFKSourceSerializer +from tests.views import NullableFKSourceDetail + + +urlpatterns = patterns( +    '', +    url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), +) + + +class NullableForeignKeyTests(APITestCase): +    """ +    DRF should be able to handle nullable foreign keys when a test +    Client POST/PUT request is made with its own serialized object. +    """ +    urls = 'tests.test_nullable_fields' + +    def test_updating_object_with_null_fk(self): +        obj = NullableForeignKeySource(name='example', target=None) +        obj.save() +        serialized_data = NullableFKSourceSerializer(obj).data + +        response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) + +        self.assertEqual(response.data, serialized_data) diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 00000000..293146c0 --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,520 @@ +from __future__ import unicode_literals +import datetime +from decimal import Decimal +from django.db import models +from django.core.paginator import Paginator +from django.test import TestCase +from django.utils import unittest +from rest_framework import generics, status, pagination, filters, serializers +from rest_framework.compat import django_filters +from rest_framework.test import APIRequestFactory +from .models import BasicModel, FilterableItem + +factory = APIRequestFactory() + +# Helper function to split arguments out of an url +def split_arguments_from_url(url): +    if '?' not in url: +        return url + +    path, args = url.split('?') +    args = dict(r.split('=') for r in args.split('&')) +    return path, args + + +class RootView(generics.ListCreateAPIView): +    """ +    Example description for OPTIONS. +    """ +    model = BasicModel +    paginate_by = 10 + + +class DefaultPageSizeKwargView(generics.ListAPIView): +    """ +    View for testing default paginate_by_param usage +    """ +    model = BasicModel + + +class PaginateByParamView(generics.ListAPIView): +    """ +    View for testing custom paginate_by_param usage +    """ +    model = BasicModel +    paginate_by_param = 'page_size' + + +class MaxPaginateByView(generics.ListAPIView): +    """ +    View for testing custom max_paginate_by usage +    """ +    model = BasicModel +    paginate_by = 3 +    max_paginate_by = 5 +    paginate_by_param = 'page_size' + + +class IntegrationTestPagination(TestCase): +    """ +    Integration tests for paginated list views. +    """ + +    def setUp(self): +        """ +        Create 26 BasicModel instances. +        """ +        for char in 'abcdefghijklmnopqrstuvwxyz': +            BasicModel(text=char * 3).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = RootView.as_view() + +    def test_get_paginated_root_view(self): +        """ +        GET requests to paginated ListCreateAPIView should return paginated results. +        """ +        request = factory.get('/') +        # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` +        with self.assertNumQueries(2): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 26) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['next'])) +        with self.assertNumQueries(2): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 26) +        self.assertEqual(response.data['results'], self.data[10:20]) +        self.assertNotEqual(response.data['next'], None) +        self.assertNotEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['next'])) +        with self.assertNumQueries(2): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 26) +        self.assertEqual(response.data['results'], self.data[20:]) +        self.assertEqual(response.data['next'], None) +        self.assertNotEqual(response.data['previous'], None) + + +class IntegrationTestPaginationAndFiltering(TestCase): + +    def setUp(self): +        """ +        Create 50 FilterableItem instances. +        """ +        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) +        for i in range(26): +            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. +            decimal = base_data[1] + i +            date = base_data[2] - datetime.timedelta(days=i * 2) +            FilterableItem(text=text, decimal=decimal, date=date).save() + +        self.objects = FilterableItem.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} +            for obj in self.objects.all() +        ] + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_get_django_filter_paginated_filtered_root_view(self): +        """ +        GET requests to paginated filtered ListCreateAPIView should return +        paginated results. The next and previous links should preserve the +        filtered parameters. +        """ +        class DecimalFilter(django_filters.FilterSet): +            decimal = django_filters.NumberFilter(lookup_type='lt') + +            class Meta: +                model = FilterableItem +                fields = ['text', 'decimal', 'date'] + +        class FilterFieldsRootView(generics.ListCreateAPIView): +            model = FilterableItem +            paginate_by = 10 +            filter_class = DecimalFilter +            filter_backends = (filters.DjangoFilterBackend,) + +        view = FilterFieldsRootView.as_view() + +        EXPECTED_NUM_QUERIES = 2 + +        request = factory.get('/', {'decimal': '15.20'}) +        with self.assertNumQueries(EXPECTED_NUM_QUERIES): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['next'])) +        with self.assertNumQueries(EXPECTED_NUM_QUERIES): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[10:15]) +        self.assertEqual(response.data['next'], None) +        self.assertNotEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['previous'])) +        with self.assertNumQueries(EXPECTED_NUM_QUERIES): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + +    def test_get_basic_paginated_filtered_root_view(self): +        """ +        Same as `test_get_django_filter_paginated_filtered_root_view`, +        except using a custom filter backend instead of the django-filter +        backend, +        """ + +        class DecimalFilterBackend(filters.BaseFilterBackend): +            def filter_queryset(self, request, queryset, view): +                return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) + +        class BasicFilterFieldsRootView(generics.ListCreateAPIView): +            model = FilterableItem +            paginate_by = 10 +            filter_backends = (DecimalFilterBackend,) + +        view = BasicFilterFieldsRootView.as_view() + +        request = factory.get('/', {'decimal': '15.20'}) +        with self.assertNumQueries(2): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['next'])) +        with self.assertNumQueries(2): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[10:15]) +        self.assertEqual(response.data['next'], None) +        self.assertNotEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['previous'])) +        with self.assertNumQueries(2): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + + +class PassOnContextPaginationSerializer(pagination.PaginationSerializer): +    class Meta: +        object_serializer_class = serializers.Serializer + + +class UnitTestPagination(TestCase): +    """ +    Unit tests for pagination of primitive objects. +    """ + +    def setUp(self): +        self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] +        paginator = Paginator(self.objects, 10) +        self.first_page = paginator.page(1) +        self.last_page = paginator.page(3) + +    def test_native_pagination(self): +        serializer = pagination.PaginationSerializer(self.first_page) +        self.assertEqual(serializer.data['count'], 26) +        self.assertEqual(serializer.data['next'], '?page=2') +        self.assertEqual(serializer.data['previous'], None) +        self.assertEqual(serializer.data['results'], self.objects[:10]) + +        serializer = pagination.PaginationSerializer(self.last_page) +        self.assertEqual(serializer.data['count'], 26) +        self.assertEqual(serializer.data['next'], None) +        self.assertEqual(serializer.data['previous'], '?page=2') +        self.assertEqual(serializer.data['results'], self.objects[20:]) + +    def test_context_available_in_result(self): +        """ +        Ensure context gets passed through to the object serializer. +        """ +        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) +        serializer.data +        results = serializer.fields[serializer.results_field] +        self.assertEqual(serializer.context, results.context) + + +class TestUnpaginated(TestCase): +    """ +    Tests for list views without pagination. +    """ + +    def setUp(self): +        """ +        Create 13 BasicModel instances. +        """ +        for i in range(13): +            BasicModel(text=i).save() +        self.objects = BasicModel.objects +        self.data = [ +        {'id': obj.id, 'text': obj.text} +        for obj in self.objects.all() +        ] +        self.view = DefaultPageSizeKwargView.as_view() + +    def test_unpaginated(self): +        """ +        Tests the default page size for this view. +        no page size --> no limit --> no meta data +        """ +        request = factory.get('/') +        response = self.view(request) +        self.assertEqual(response.data, self.data) + + +class TestCustomPaginateByParam(TestCase): +    """ +    Tests for list views with default page size kwarg +    """ + +    def setUp(self): +        """ +        Create 13 BasicModel instances. +        """ +        for i in range(13): +            BasicModel(text=i).save() +        self.objects = BasicModel.objects +        self.data = [ +        {'id': obj.id, 'text': obj.text} +        for obj in self.objects.all() +        ] +        self.view = PaginateByParamView.as_view() + +    def test_default_page_size(self): +        """ +        Tests the default page size for this view. +        no page size --> no limit --> no meta data +        """ +        request = factory.get('/') +        response = self.view(request).render() +        self.assertEqual(response.data, self.data) + +    def test_paginate_by_param(self): +        """ +        If paginate_by_param is set, the new kwarg should limit per view requests. +        """ +        request = factory.get('/', {'page_size': 5}) +        response = self.view(request).render() +        self.assertEqual(response.data['count'], 13) +        self.assertEqual(response.data['results'], self.data[:5]) + + +class TestMaxPaginateByParam(TestCase): +    """ +    Tests for list views with max_paginate_by kwarg +    """ + +    def setUp(self): +        """ +        Create 13 BasicModel instances. +        """ +        for i in range(13): +            BasicModel(text=i).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = MaxPaginateByView.as_view() + +    def test_max_paginate_by(self): +        """ +        If max_paginate_by is set, it should limit page size for the view. +        """ +        request = factory.get('/', data={'page_size': 10}) +        response = self.view(request).render() +        self.assertEqual(response.data['count'], 13) +        self.assertEqual(response.data['results'], self.data[:5]) + +    def test_max_paginate_by_without_page_size_param(self): +        """ +        If max_paginate_by is set, but client does not specifiy page_size, +        standard `paginate_by` behavior should be used. +        """ +        request = factory.get('/') +        response = self.view(request).render() +        self.assertEqual(response.data['results'], self.data[:3]) + + +### Tests for context in pagination serializers + +class CustomField(serializers.Field): +    def to_native(self, value): +        if not 'view' in self.context: +            raise RuntimeError("context isn't getting passed into custom field") +        return "value" + + +class BasicModelSerializer(serializers.Serializer): +    text = CustomField() + +    def __init__(self, *args, **kwargs): +        super(BasicModelSerializer, self).__init__(*args, **kwargs) +        if not 'view' in self.context: +            raise RuntimeError("context isn't getting passed into serializer init") + + +class TestContextPassedToCustomField(TestCase): +    def setUp(self): +        BasicModel.objects.create(text='ala ma kota') + +    def test_with_pagination(self): +        class ListView(generics.ListCreateAPIView): +            model = BasicModel +            serializer_class = BasicModelSerializer +            paginate_by = 1 + +        self.view = ListView.as_view() +        request = factory.get('/') +        response = self.view(request).render() + +        self.assertEqual(response.status_code, status.HTTP_200_OK) + + +### Tests for custom pagination serializers + +class LinksSerializer(serializers.Serializer): +    next = pagination.NextPageField(source='*') +    prev = pagination.PreviousPageField(source='*') + + +class CustomPaginationSerializer(pagination.BasePaginationSerializer): +    links = LinksSerializer(source='*')  # Takes the page object as the source +    total_results = serializers.Field(source='paginator.count') + +    results_field = 'objects' + + +class TestCustomPaginationSerializer(TestCase): +    def setUp(self): +        objects = ['john', 'paul', 'george', 'ringo'] +        paginator = Paginator(objects, 2) +        self.page = paginator.page(1) + +    def test_custom_pagination_serializer(self): +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=self.page, +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': 'http://testserver/foobar?page=2', +                'prev': None +            }, +            'total_results': 4, +            'objects': ['john', 'paul'] +        } +        self.assertEqual(serializer.data, expected) + + +class NonIntegerPage(object): + +    def __init__(self, paginator, object_list, prev_token, token, next_token): +        self.paginator = paginator +        self.object_list = object_list +        self.prev_token = prev_token +        self.token = token +        self.next_token = next_token + +    def has_next(self): +        return not not self.next_token + +    def next_page_number(self): +        return self.next_token + +    def has_previous(self): +        return not not self.prev_token + +    def previous_page_number(self): +        return self.prev_token + + +class NonIntegerPaginator(object): + +    def __init__(self, object_list, per_page): +        self.object_list = object_list +        self.per_page = per_page + +    def count(self): +        # pretend like we don't know how many pages we have +        return None + +    def page(self, token=None): +        if token: +            try: +                first = self.object_list.index(token) +            except ValueError: +                first = 0 +        else: +            first = 0 +        n = len(self.object_list) +        last = min(first + self.per_page, n) +        prev_token = self.object_list[last - (2 * self.per_page)] if first else None +        next_token = self.object_list[last] if last < n else None +        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) + + +class TestNonIntegerPagination(TestCase): + + +    def test_custom_pagination_serializer(self): +        objects = ['john', 'paul', 'george', 'ringo'] +        paginator = NonIntegerPaginator(objects, 2) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page(), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), +                'prev': None +            }, +            'total_results': None, +            'objects': objects[:2] +        } +        self.assertEqual(serializer.data, expected) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page('george'), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': None, +                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), +            }, +            'total_results': None, +            'objects': objects[2:] +        } +        self.assertEqual(serializer.data, expected) diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 00000000..8af90677 --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,115 @@ +from __future__ import unicode_literals +from rest_framework.compat import StringIO +from django import forms +from django.core.files.uploadhandler import MemoryFileUploadHandler +from django.test import TestCase +from django.utils import unittest +from rest_framework.compat import etree +from rest_framework.parsers import FormParser, FileUploadParser +from rest_framework.parsers import XMLParser +import datetime + + +class Form(forms.Form): +    field1 = forms.CharField(max_length=3) +    field2 = forms.CharField() + + +class TestFormParser(TestCase): +    def setUp(self): +        self.string = "field1=abc&field2=defghijk" + +    def test_parse(self): +        """ Make sure the `QueryDict` works OK """ +        parser = FormParser() + +        stream = StringIO(self.string) +        data = parser.parse(stream) + +        self.assertEqual(Form(data).is_valid(), True) + + +class TestXMLParser(TestCase): +    def setUp(self): +        self._input = StringIO( +            '<?xml version="1.0" encoding="utf-8"?>' +            '<root>' +            '<field_a>121.0</field_a>' +            '<field_b>dasd</field_b>' +            '<field_c></field_c>' +            '<field_d>2011-12-25 12:45:00</field_d>' +            '</root>' +        ) +        self._data = { +            'field_a': 121, +            'field_b': 'dasd', +            'field_c': None, +            'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00) +        } +        self._complex_data_input = StringIO( +            '<?xml version="1.0" encoding="utf-8"?>' +            '<root>' +            '<creation_date>2011-12-25 12:45:00</creation_date>' +            '<sub_data_list>' +            '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>' +            '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>' +            '</sub_data_list>' +            '<name>name</name>' +            '</root>' +        ) +        self._complex_data = { +            "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), +            "name": "name", +            "sub_data_list": [ +                { +                    "sub_id": 1, +                    "sub_name": "first" +                }, +                { +                    "sub_id": 2, +                    "sub_name": "second" +                } +            ] +        } + +    @unittest.skipUnless(etree, 'defusedxml not installed') +    def test_parse(self): +        parser = XMLParser() +        data = parser.parse(self._input) +        self.assertEqual(data, self._data) + +    @unittest.skipUnless(etree, 'defusedxml not installed') +    def test_complex_data_parse(self): +        parser = XMLParser() +        data = parser.parse(self._complex_data_input) +        self.assertEqual(data, self._complex_data) + + +class TestFileUploadParser(TestCase): +    def setUp(self): +        class MockRequest(object): +            pass +        from io import BytesIO +        self.stream = BytesIO( +            "Test text file".encode('utf-8') +        ) +        request = MockRequest() +        request.upload_handlers = (MemoryFileUploadHandler(),) +        request.META = { +            'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt', +            'HTTP_CONTENT_LENGTH': 14, +        } +        self.parser_context = {'request': request, 'kwargs': {}} + +    def test_parse(self): +        """ Make sure the `QueryDict` works OK """ +        parser = FileUploadParser() +        self.stream.seek(0) +        data_and_files = parser.parse(self.stream, None, self.parser_context) +        file_obj = data_and_files.files['file'] +        self.assertEqual(file_obj._size, 14) + +    def test_get_filename(self): +        parser = FileUploadParser() +        filename = parser.get_filename(self.stream, None, self.parser_context) +        self.assertEqual(filename, 'file.txt') diff --git a/tests/test_permissions.py b/tests/test_permissions.py new file mode 100644 index 00000000..a2cb0c36 --- /dev/null +++ b/tests/test_permissions.py @@ -0,0 +1,291 @@ +from __future__ import unicode_literals +from django.contrib.auth.models import User, Permission, Group +from django.db import models +from django.test import TestCase +from django.utils import unittest +from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING +from rest_framework.compat import guardian, get_model_name +from rest_framework.filters import DjangoObjectPermissionsFilter +from rest_framework.test import APIRequestFactory +from tests.models import BasicModel +import base64 + +factory = APIRequestFactory() + +class RootView(generics.ListCreateAPIView): +    model = BasicModel +    authentication_classes = [authentication.BasicAuthentication] +    permission_classes = [permissions.DjangoModelPermissions] + + +class InstanceView(generics.RetrieveUpdateDestroyAPIView): +    model = BasicModel +    authentication_classes = [authentication.BasicAuthentication] +    permission_classes = [permissions.DjangoModelPermissions] + +root_view = RootView.as_view() +instance_view = InstanceView.as_view() + + +def basic_auth_header(username, password): +    credentials = ('%s:%s' % (username, password)) +    base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) +    return 'Basic %s' % base64_credentials + + +class ModelPermissionsIntegrationTests(TestCase): +    def setUp(self): +        User.objects.create_user('disallowed', 'disallowed@example.com', 'password') +        user = User.objects.create_user('permitted', 'permitted@example.com', 'password') +        user.user_permissions = [ +            Permission.objects.get(codename='add_basicmodel'), +            Permission.objects.get(codename='change_basicmodel'), +            Permission.objects.get(codename='delete_basicmodel') +        ] +        user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') +        user.user_permissions = [ +            Permission.objects.get(codename='change_basicmodel'), +        ] + +        self.permitted_credentials = basic_auth_header('permitted', 'password') +        self.disallowed_credentials = basic_auth_header('disallowed', 'password') +        self.updateonly_credentials = basic_auth_header('updateonly', 'password') + +        BasicModel(text='foo').save() + +    def test_has_create_permissions(self): +        request = factory.post('/', {'text': 'foobar'}, format='json', +                               HTTP_AUTHORIZATION=self.permitted_credentials) +        response = root_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_201_CREATED) + +    def test_has_put_permissions(self): +        request = factory.put('/1', {'text': 'foobar'}, format='json', +                              HTTP_AUTHORIZATION=self.permitted_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_has_delete_permissions(self): +        request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) +        response = instance_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + +    def test_does_not_have_create_permissions(self): +        request = factory.post('/', {'text': 'foobar'}, format='json', +                               HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = root_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_does_not_have_put_permissions(self): +        request = factory.put('/1', {'text': 'foobar'}, format='json', +                              HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_does_not_have_delete_permissions(self): +        request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = instance_view(request, pk=1) +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_has_put_as_create_permissions(self): +        # User only has update permissions - should be able to update an entity. +        request = factory.put('/1', {'text': 'foobar'}, format='json', +                              HTTP_AUTHORIZATION=self.updateonly_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +        # But if PUTing to a new entity, permission should be denied. +        request = factory.put('/2', {'text': 'foobar'}, format='json', +                              HTTP_AUTHORIZATION=self.updateonly_credentials) +        response = instance_view(request, pk='2') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    def test_options_permitted(self): +        request = factory.options('/', +                               HTTP_AUTHORIZATION=self.permitted_credentials) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEqual(list(response.data['actions'].keys()), ['POST']) + +        request = factory.options('/1', +                               HTTP_AUTHORIZATION=self.permitted_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + +    def test_options_disallowed(self): +        request = factory.options('/', +                               HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertNotIn('actions', response.data) + +        request = factory.options('/1', +                               HTTP_AUTHORIZATION=self.disallowed_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertNotIn('actions', response.data) + +    def test_options_updateonly(self): +        request = factory.options('/', +                               HTTP_AUTHORIZATION=self.updateonly_credentials) +        response = root_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertNotIn('actions', response.data) + +        request = factory.options('/1', +                               HTTP_AUTHORIZATION=self.updateonly_credentials) +        response = instance_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertIn('actions', response.data) +        self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + + +class BasicPermModel(models.Model): +    text = models.CharField(max_length=100) + +    class Meta: +        app_label = 'tests' +        permissions = ( +            ('view_basicpermmodel', 'Can view basic perm model'), +            # add, change, delete built in to django +        ) + +# Custom object-level permission, that includes 'view' permissions +class ViewObjectPermissions(permissions.DjangoObjectPermissions): +    perms_map = { +        'GET': ['%(app_label)s.view_%(model_name)s'], +        'OPTIONS': ['%(app_label)s.view_%(model_name)s'], +        'HEAD': ['%(app_label)s.view_%(model_name)s'], +        'POST': ['%(app_label)s.add_%(model_name)s'], +        'PUT': ['%(app_label)s.change_%(model_name)s'], +        'PATCH': ['%(app_label)s.change_%(model_name)s'], +        'DELETE': ['%(app_label)s.delete_%(model_name)s'], +    } + + +class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView): +    model = BasicPermModel +    authentication_classes = [authentication.BasicAuthentication] +    permission_classes = [ViewObjectPermissions] + +object_permissions_view = ObjectPermissionInstanceView.as_view() + + +class ObjectPermissionListView(generics.ListAPIView): +    model = BasicPermModel +    authentication_classes = [authentication.BasicAuthentication] +    permission_classes = [ViewObjectPermissions] + +object_permissions_list_view = ObjectPermissionListView.as_view() + + +@unittest.skipUnless(guardian, 'django-guardian not installed') +class ObjectPermissionsIntegrationTests(TestCase): +    """ +    Integration tests for the object level permissions API. +    """ +    def setUp(self): +        from guardian.shortcuts import assign_perm + +        # create users +        create = User.objects.create_user +        users = { +            'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), +            'readonly': create('readonly', 'readonly@example.com', 'password'), +            'writeonly': create('writeonly', 'writeonly@example.com', 'password'), +            'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), +        } + +        # give everyone model level permissions, as we are not testing those +        everyone = Group.objects.create(name='everyone') +        model_name = get_model_name(BasicPermModel) +        app_label = BasicPermModel._meta.app_label +        f = '{0}_{1}'.format +        perms = { +            'view':   f('view', model_name), +            'change': f('change', model_name), +            'delete': f('delete', model_name) +        } +        for perm in perms.values(): +            perm = '{0}.{1}'.format(app_label, perm) +            assign_perm(perm, everyone) +        everyone.user_set.add(*users.values()) + +        # appropriate object level permissions +        readers = Group.objects.create(name='readers') +        writers = Group.objects.create(name='writers') +        deleters = Group.objects.create(name='deleters') + +        model = BasicPermModel.objects.create(text='foo') + +        assign_perm(perms['view'], readers, model) +        assign_perm(perms['change'], writers, model) +        assign_perm(perms['delete'], deleters, model) + +        readers.user_set.add(users['fullaccess'], users['readonly']) +        writers.user_set.add(users['fullaccess'], users['writeonly']) +        deleters.user_set.add(users['fullaccess'], users['deleteonly']) + +        self.credentials = {} +        for user in users.values(): +            self.credentials[user.username] = basic_auth_header(user.username, 'password') + +    # Delete +    def test_can_delete_permissions(self): +        request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) + +    def test_cannot_delete_permissions(self): +        request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + +    # Update +    def test_can_update_permissions(self): +        request = factory.patch('/1', {'text': 'foobar'}, format='json', +            HTTP_AUTHORIZATION=self.credentials['writeonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data.get('text'), 'foobar') + +    def test_cannot_update_permissions(self): +        request = factory.patch('/1', {'text': 'foobar'}, format='json', +            HTTP_AUTHORIZATION=self.credentials['deleteonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    def test_cannot_update_permissions_non_existing(self): +        request = factory.patch('/999', {'text': 'foobar'}, format='json', +            HTTP_AUTHORIZATION=self.credentials['deleteonly']) +        response = object_permissions_view(request, pk='999') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    # Read +    def test_can_read_permissions(self): +        request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_cannot_read_permissions(self): +        request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) +        response = object_permissions_view(request, pk='1') +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +    # Read list +    def test_can_read_list_permissions(self): +        request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) +        object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) +        response = object_permissions_list_view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data[0].get('id'), 1) + +    def test_cannot_read_list_permissions(self): +        request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) +        object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) +        response = object_permissions_list_view(request) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertListEqual(response.data, []) diff --git a/tests/test_relations.py b/tests/test_relations.py new file mode 100644 index 00000000..cd276d30 --- /dev/null +++ b/tests/test_relations.py @@ -0,0 +1,144 @@ +""" +General tests for relational fields. +""" +from __future__ import unicode_literals +from django import get_version +from django.db import models +from django.test import TestCase +from django.utils import unittest +from rest_framework import serializers +from tests.models import BlogPost + + +class NullModel(models.Model): +    pass + + +class FieldTests(TestCase): +    def test_pk_related_field_with_empty_string(self): +        """ +        Regression test for #446 + +        https://github.com/tomchristie/django-rest-framework/issues/446 +        """ +        field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) + +    def test_hyperlinked_related_field_with_empty_string(self): +        field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) + +    def test_slug_related_field_with_empty_string(self): +        field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) + + +class TestManyRelatedMixin(TestCase): +    def test_missing_many_to_many_related_field(self): +        ''' +        Regression test for #632 + +        https://github.com/tomchristie/django-rest-framework/pull/632 +        ''' +        field = serializers.RelatedField(many=True, read_only=False) + +        into = {} +        field.field_from_native({}, None, 'field_name', into) +        self.assertEqual(into['field_name'], []) + + +# Regression tests for #694 (`source` attribute on related fields) + +class RelatedFieldSourceTests(TestCase): +    def test_related_manager_source(self): +        """ +        Relational fields should be able to use manager-returning methods as their source. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.RelatedField(many=True, source='get_blogposts_manager') + +        class ClassWithManagerMethod(object): +            def get_blogposts_manager(self): +                return BlogPost.objects + +        obj = ClassWithManagerMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, ['BlogPost object']) + +    def test_related_queryset_source(self): +        """ +        Relational fields should be able to use queryset-returning methods as their source. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.RelatedField(many=True, source='get_blogposts_queryset') + +        class ClassWithQuerysetMethod(object): +            def get_blogposts_queryset(self): +                return BlogPost.objects.all() + +        obj = ClassWithQuerysetMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, ['BlogPost object']) + +    def test_dotted_source(self): +        """ +        Source argument should support dotted.source notation. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.RelatedField(many=True, source='a.b.c') + +        class ClassWithQuerysetMethod(object): +            a = { +                'b': { +                    'c': BlogPost.objects.all() +                } +            } + +        obj = ClassWithQuerysetMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, ['BlogPost object']) + +    # Regression for #1129 +    def test_exception_for_incorect_fk(self): +        """ +        Check that the exception message are correct if the source field +        doesn't exist. +        """ +        from tests.models import ManyToManySource +        class Meta: +            model = ManyToManySource +        attrs = { +            'name': serializers.SlugRelatedField( +                slug_field='name', source='banzai'), +            'Meta': Meta, +        } + +        TestSerializer = type(str('TestSerializer'), +            (serializers.ModelSerializer,), attrs) +        with self.assertRaises(AttributeError): +            TestSerializer(data={'name': 'foo'}) + +@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') +class RelatedFieldChoicesTests(TestCase): +    """ +    Tests for #1408 "Web browseable API doesn't have blank option on drop down list box" +    https://github.com/tomchristie/django-rest-framework/issues/1408 +    """ +    def test_blank_option_is_added_to_choice_if_required_equals_false(self): +        """ + +        """ +        post = BlogPost(title="Checking blank option is added") +        post.save() + +        queryset = BlogPost.objects.all() +        field = serializers.RelatedField(required=False, queryset=queryset) + +        choice_count = BlogPost.objects.count() +        widget_count = len(field.widget.choices) + +        self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') + diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py new file mode 100644 index 00000000..98f68d29 --- /dev/null +++ b/tests/test_relations_hyperlink.py @@ -0,0 +1,524 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers +from rest_framework.compat import patterns, url +from rest_framework.test import APIRequestFactory +from tests.models import ( +    BlogPost, +    ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, +    NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +) + +factory = APIRequestFactory() +request = factory.get('/')  # Just to ensure we have a request in the serializer context + + +def dummy_view(request, pk): +    pass + +urlpatterns = patterns('', +    url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), +    url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), +    url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), +    url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), +    url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), +    url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), +    url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), +    url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), +) + + +# ManyToMany +class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ManyToManyTarget +        fields = ('url', 'name', 'sources') + + +class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ManyToManySource +        fields = ('url', 'name', 'targets') + + +# ForeignKey +class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ForeignKeyTarget +        fields = ('url', 'name', 'sources') + + +class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ForeignKeySource +        fields = ('url', 'name', 'target') + + +# Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = NullableForeignKeySource +        fields = ('url', 'name', 'target') + + +# Nullable OneToOne +class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = OneToOneTarget +        fields = ('url', 'name', 'nullable_source') + + +# TODO: Add test that .data cannot be accessed prior to .is_valid + +class HyperlinkedManyToManyTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def setUp(self): +        for idx in range(1, 4): +            target = ManyToManyTarget(name='target-%d' % idx) +            target.save() +            source = ManyToManySource(name='source-%d' % idx) +            source.save() +            for target in ManyToManyTarget.objects.all(): +                source.targets.add(target) + +    def test_many_to_many_retrieve(self): +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +                {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, +                {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, +                {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_retrieve(self): +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_many_to_many_update(self): +        data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} +        instance = ManyToManySource.objects.get(pk=1) +        serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +                {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, +                {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, +                {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_update(self): +        data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} +        instance = ManyToManyTarget.objects.get(pk=1) +        serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 1 is updated, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, +            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} + +        ] +        self.assertEqual(serializer.data, expected) + +    def test_many_to_many_create(self): +        data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} +        serializer = ManyToManySourceSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, +            {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, +            {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, +            {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_create(self): +        data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} +        serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-4') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, +            {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} +        ] +        self.assertEqual(serializer.data, expected) + + +class HyperlinkedForeignKeyTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, +            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update(self): +        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, +            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_incorrect_type(self): +        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['Incorrect type.  Expected url string, received int.']}) + +    def test_reverse_foreign_key_update(self): +        data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        # We shouldn't have saved anything to the db yet since save +        # hasn't been called. +        queryset = ForeignKeyTarget.objects.all() +        new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, +            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(new_serializer.data, expected) + +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, +            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} +        serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} +        serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, +            {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, +            {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + + +class HyperlinkedNullableForeignKeyTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +            {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} +        expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, expected_data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +            {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} +        expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, expected_data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, +            {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, +            {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    # reverse foreign keys MUST be read_only +    # In the general case they do not provide .remove() or .clear() +    # and cannot be arbitrarily set. + +    # def test_reverse_foreign_key_update(self): +    #     data = {'id': 1, 'name': 'target-1', 'sources': [1]} +    #     instance = ForeignKeyTarget.objects.get(pk=1) +    #     serializer = ForeignKeyTargetSerializer(instance, data=data) +    #     self.assertTrue(serializer.is_valid()) +    #     self.assertEqual(serializer.data, data) +    #     serializer.save() + +    #     # Ensure target 1 is updated, and everything else is as expected +    #     queryset = ForeignKeyTarget.objects.all() +    #     serializer = ForeignKeyTargetSerializer(queryset, many=True) +    #     expected = [ +    #         {'id': 1, 'name': 'target-1', 'sources': [1]}, +    #         {'id': 2, 'name': 'target-2', 'sources': []}, +    #     ] +    #     self.assertEqual(serializer.data, expected) + + +class HyperlinkedNullableOneToOneTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) +        expected = [ +            {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, +            {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, +        ] +        self.assertEqual(serializer.data, expected) + + +# Regression tests for #694 (`source` attribute on related fields) + +class HyperlinkedRelatedFieldSourceTests(TestCase): +    urls = 'tests.test_relations_hyperlink' + +    def test_related_manager_source(self): +        """ +        Relational fields should be able to use manager-returning methods as their source. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.HyperlinkedRelatedField( +            many=True, +            source='get_blogposts_manager', +            view_name='dummy-url', +        ) +        field.context = {'request': request} + +        class ClassWithManagerMethod(object): +            def get_blogposts_manager(self): +                return BlogPost.objects + +        obj = ClassWithManagerMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, ['http://testserver/dummyurl/1/']) + +    def test_related_queryset_source(self): +        """ +        Relational fields should be able to use queryset-returning methods as their source. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.HyperlinkedRelatedField( +            many=True, +            source='get_blogposts_queryset', +            view_name='dummy-url', +        ) +        field.context = {'request': request} + +        class ClassWithQuerysetMethod(object): +            def get_blogposts_queryset(self): +                return BlogPost.objects.all() + +        obj = ClassWithQuerysetMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, ['http://testserver/dummyurl/1/']) + +    def test_dotted_source(self): +        """ +        Source argument should support dotted.source notation. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.HyperlinkedRelatedField( +            many=True, +            source='a.b.c', +            view_name='dummy-url', +        ) +        field.context = {'request': request} + +        class ClassWithQuerysetMethod(object): +            a = { +                'b': { +                    'c': BlogPost.objects.all() +                } +            } + +        obj = ClassWithQuerysetMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, ['http://testserver/dummyurl/1/']) diff --git a/tests/test_relations_nested.py b/tests/test_relations_nested.py new file mode 100644 index 00000000..4d9da489 --- /dev/null +++ b/tests/test_relations_nested.py @@ -0,0 +1,326 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers + +from .models import OneToOneTarget + + +class OneToOneSource(models.Model): +    name = models.CharField(max_length=100) +    target = models.OneToOneField(OneToOneTarget, related_name='source', +                                  null=True, blank=True) + + +class OneToManyTarget(models.Model): +    name = models.CharField(max_length=100) + + +class OneToManySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(OneToManyTarget, related_name='sources') + + +class ReverseNestedOneToOneTests(TestCase): +    def setUp(self): +        class OneToOneSourceSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToOneSource +                fields = ('id', 'name') + +        class OneToOneTargetSerializer(serializers.ModelSerializer): +            source = OneToOneSourceSerializer() + +            class Meta: +                model = OneToOneTarget +                fields = ('id', 'name', 'source') + +        self.Serializer = OneToOneTargetSerializer + +        for idx in range(1, 4): +            target = OneToOneTarget(name='target-%d' % idx) +            target.save() +            source = OneToOneSource(name='source-%d' % idx, target=target) +            source.save() + +    def test_one_to_one_retrieve(self): +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create(self): +        data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} +        serializer = self.Serializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-4') + +        # Ensure (target 4, target_source 4, source 4) are added, and +        # everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, +            {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create_with_invalid_data(self): +        data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) + +    def test_one_to_one_update(self): +        data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} +        instance = OneToOneTarget.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3-updated') + +        # Ensure (target 3, target_source 3, source 3) are updated, +        # and everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} +        ] +        self.assertEqual(serializer.data, expected) + + +class ForwardNestedOneToOneTests(TestCase): +    def setUp(self): +        class OneToOneTargetSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToOneTarget +                fields = ('id', 'name') + +        class OneToOneSourceSerializer(serializers.ModelSerializer): +            target = OneToOneTargetSerializer() + +            class Meta: +                model = OneToOneSource +                fields = ('id', 'name', 'target') + +        self.Serializer = OneToOneSourceSerializer + +        for idx in range(1, 4): +            target = OneToOneTarget(name='target-%d' % idx) +            target.save() +            source = OneToOneSource(name='source-%d' % idx, target=target) +            source.save() + +    def test_one_to_one_retrieve(self): +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create(self): +        data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} +        serializer = self.Serializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure (target 4, target_source 4, source 4) are added, and +        # everything else is as expected. +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, +            {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create_with_invalid_data(self): +        data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) + +    def test_one_to_one_update(self): +        data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} +        instance = OneToOneSource.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-3-updated') + +        # Ensure (target 3, target_source 3, source 3) are updated, +        # and everything else is as expected. +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_update_to_null(self): +        data = {'id': 3, 'name': 'source-3-updated', 'target': None} +        instance = OneToOneSource.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() + +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-3-updated') +        self.assertEqual(obj.target, None) + +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3-updated', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    # TODO: Nullable 1-1 tests +    # def test_one_to_one_delete(self): +    #     data = {'id': 3, 'name': 'target-3', 'target_source': None} +    #     instance = OneToOneTarget.objects.get(pk=3) +    #     serializer = self.Serializer(instance, data=data) +    #     self.assertTrue(serializer.is_valid()) +    #     serializer.save() + +    #     # Ensure (target_source 3, source 3) are deleted, +    #     # and everything else is as expected. +    #     queryset = OneToOneTarget.objects.all() +    #     serializer = self.Serializer(queryset) +    #     expected = [ +    #         {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +    #         {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +    #         {'id': 3, 'name': 'target-3', 'source': None} +    #     ] +    #     self.assertEqual(serializer.data, expected) + + +class ReverseNestedOneToManyTests(TestCase): +    def setUp(self): +        class OneToManySourceSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToManySource +                fields = ('id', 'name') + +        class OneToManyTargetSerializer(serializers.ModelSerializer): +            sources = OneToManySourceSerializer(many=True, allow_add_remove=True) + +            class Meta: +                model = OneToManyTarget +                fields = ('id', 'name', 'sources') + +        self.Serializer = OneToManyTargetSerializer + +        target = OneToManyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            source = OneToManySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_one_to_many_retrieve(self): +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 2, 'name': 'source-2'}, +                                                      {'id': 3, 'name': 'source-3'}]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_create(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 2, 'name': 'source-2'}, +                                                         {'id': 3, 'name': 'source-3'}, +                                                         {'id': 4, 'name': 'source-4'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-1') + +        # Ensure source 4 is added, and everything else is as +        # expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 2, 'name': 'source-2'}, +                                                      {'id': 3, 'name': 'source-3'}, +                                                      {'id': 4, 'name': 'source-4'}]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_create_with_invalid_data(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 2, 'name': 'source-2'}, +                                                         {'id': 3, 'name': 'source-3'}, +                                                         {'id': 4}]} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) + +    def test_one_to_many_update(self): +        data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, +                                                                 {'id': 2, 'name': 'source-2'}, +                                                                 {'id': 3, 'name': 'source-3'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-1-updated') + +        # Ensure (target 1, source 1) are updated, +        # and everything else is as expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, +                                                              {'id': 2, 'name': 'source-2'}, +                                                              {'id': 3, 'name': 'source-3'}]} + +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_delete(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 3, 'name': 'source-3'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() + +        # Ensure source 2 is deleted, and everything else is as +        # expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 3, 'name': 'source-3'}]} + +        ] +        self.assertEqual(serializer.data, expected) diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py new file mode 100644 index 00000000..ff59b250 --- /dev/null +++ b/tests/test_relations_pk.py @@ -0,0 +1,551 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from tests.models import ( +    BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, +    NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, +) +from rest_framework.compat import six + + +# ManyToMany +class ManyToManyTargetSerializer(serializers.ModelSerializer): +    class Meta: +        model = ManyToManyTarget +        fields = ('id', 'name', 'sources') + + +class ManyToManySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = ManyToManySource +        fields = ('id', 'name', 'targets') + + +# ForeignKey +class ForeignKeyTargetSerializer(serializers.ModelSerializer): +    class Meta: +        model = ForeignKeyTarget +        fields = ('id', 'name', 'sources') + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = ForeignKeySource +        fields = ('id', 'name', 'target') + + +# Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableForeignKeySource +        fields = ('id', 'name', 'target') + + +# Nullable OneToOne +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): +    class Meta: +        model = OneToOneTarget +        fields = ('id', 'name', 'nullable_source') + + +# TODO: Add test that .data cannot be accessed prior to .is_valid + +class PKManyToManyTests(TestCase): +    def setUp(self): +        for idx in range(1, 4): +            target = ManyToManyTarget(name='target-%d' % idx) +            target.save() +            source = ManyToManySource(name='source-%d' % idx) +            source.save() +            for target in ManyToManyTarget.objects.all(): +                source.targets.add(target) + +    def test_many_to_many_retrieve(self): +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True) +        expected = [ +                {'id': 1, 'name': 'source-1', 'targets': [1]}, +                {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, +                {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_retrieve(self): +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, +            {'id': 3, 'name': 'target-3', 'sources': [3]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_many_to_many_update(self): +        data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} +        instance = ManyToManySource.objects.get(pk=1) +        serializer = ManyToManySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True) +        expected = [ +                {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}, +                {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, +                {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_update(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [1]} +        instance = ManyToManyTarget.objects.get(pk=1) +        serializer = ManyToManyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 1 is updated, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1]}, +            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, +            {'id': 3, 'name': 'target-3', 'sources': [3]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_many_to_many_create(self): +        data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]} +        serializer = ManyToManySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(queryset, many=True) +        self.assertFalse(serializer.fields['targets'].read_only) +        expected = [ +            {'id': 1, 'name': 'source-1', 'targets': [1]}, +            {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, +            {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}, +            {'id': 4, 'name': 'source-4', 'targets': [1, 3]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_many_to_many_create(self): +        data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} +        serializer = ManyToManyTargetSerializer(data=data) +        self.assertFalse(serializer.fields['sources'].read_only) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-4') + +        # Ensure target 4 is added, and everything else is as expected +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, +            {'id': 3, 'name': 'target-3', 'sources': [3]}, +            {'id': 4, 'name': 'target-4', 'sources': [1, 3]} +        ] +        self.assertEqual(serializer.data, expected) + + +class PKForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': 1} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update(self): +        data = {'id': 1, 'name': 'source-1', 'target': 2} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 2}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': 1} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_incorrect_type(self): +        data = {'id': 1, 'name': 'source-1', 'target': 'foo'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['Incorrect type.  Expected pk value, received %s.' % six.text_type.__name__]}) + +    def test_reverse_foreign_key_update(self): +        data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        # We shouldn't have saved anything to the db yet since save +        # hasn't been called. +        queryset = ForeignKeyTarget.objects.all() +        new_serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(new_serializer.data, expected) + +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [2]}, +            {'id': 2, 'name': 'target-2', 'sources': [1, 3]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'id': 4, 'name': 'source-4', 'target': 2} +        serializer = ForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': 1}, +            {'id': 4, 'name': 'source-4', 'target': 2}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]} +        serializer = ForeignKeyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3') + +        # Ensure target 3 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [2]}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +            {'id': 3, 'name': 'target-3', 'sources': [1, 3]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + +    def test_foreign_key_with_empty(self): +        """ +        Regression test for #1072 + +        https://github.com/tomchristie/django-rest-framework/issues/1072 +        """ +        serializer = NullableForeignKeySourceSerializer() +        self.assertEqual(serializer.data['target'], None) + + +class PKNullableForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'id': 4, 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 4, 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 4, 'name': 'source-4', 'target': ''} +        expected_data = {'id': 4, 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, expected_data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 1}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 4, 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 1, 'name': 'source-1', 'target': ''} +        expected_data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, expected_data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': 1}, +            {'id': 3, 'name': 'source-3', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    # reverse foreign keys MUST be read_only +    # In the general case they do not provide .remove() or .clear() +    # and cannot be arbitrarily set. + +    # def test_reverse_foreign_key_update(self): +    #     data = {'id': 1, 'name': 'target-1', 'sources': [1]} +    #     instance = ForeignKeyTarget.objects.get(pk=1) +    #     serializer = ForeignKeyTargetSerializer(instance, data=data) +    #     self.assertTrue(serializer.is_valid()) +    #     self.assertEqual(serializer.data, data) +    #     serializer.save() + +    #     # Ensure target 1 is updated, and everything else is as expected +    #     queryset = ForeignKeyTarget.objects.all() +    #     serializer = ForeignKeyTargetSerializer(queryset, many=True) +    #     expected = [ +    #         {'id': 1, 'name': 'target-1', 'sources': [1]}, +    #         {'id': 2, 'name': 'target-2', 'sources': []}, +    #     ] +    #     self.assertEqual(serializer.data, expected) + + +class PKNullableOneToOneTests(TestCase): +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=new_target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'nullable_source': None}, +            {'id': 2, 'name': 'target-2', 'nullable_source': 1}, +        ] +        self.assertEqual(serializer.data, expected) + + +# The below models and tests ensure that serializer fields corresponding +# to a ManyToManyField field with a user-specified ``through`` model are +# set to read only + + +class ManyToManyThroughTarget(models.Model): +    name = models.CharField(max_length=100) + + +class ManyToManyThrough(models.Model): +    source = models.ForeignKey('ManyToManyThroughSource') +    target = models.ForeignKey(ManyToManyThroughTarget) + + +class ManyToManyThroughSource(models.Model): +    name = models.CharField(max_length=100) +    targets = models.ManyToManyField(ManyToManyThroughTarget, +                                     related_name='sources', +                                     through='ManyToManyThrough') + + +class ManyToManyThroughTargetSerializer(serializers.ModelSerializer): +    class Meta: +        model = ManyToManyThroughTarget +        fields = ('id', 'name', 'sources') + + +class ManyToManyThroughSourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = ManyToManyThroughSource +        fields = ('id', 'name', 'targets') + + +class PKManyToManyThroughTests(TestCase): +    def setUp(self): +        self.source = ManyToManyThroughSource.objects.create( +            name='through-source-1') +        self.target = ManyToManyThroughTarget.objects.create( +            name='through-target-1') + +    def test_many_to_many_create(self): +        data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]} +        serializer = ManyToManyThroughSourceSerializer(data=data) +        self.assertTrue(serializer.fields['targets'].read_only) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(obj.name, 'source-2') +        self.assertEqual(obj.targets.count(), 0) + +    def test_many_to_many_reverse_create(self): +        data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]} +        serializer = ManyToManyThroughTargetSerializer(data=data) +        self.assertTrue(serializer.fields['sources'].read_only) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        obj = serializer.save() +        self.assertEqual(obj.name, 'target-2') +        self.assertEqual(obj.sources.count(), 0) + + +# Regression tests for #694 (`source` attribute on related fields) + + +class PrimaryKeyRelatedFieldSourceTests(TestCase): +    def test_related_manager_source(self): +        """ +        Relational fields should be able to use manager-returning methods as their source. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager') + +        class ClassWithManagerMethod(object): +            def get_blogposts_manager(self): +                return BlogPost.objects + +        obj = ClassWithManagerMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, [1]) + +    def test_related_queryset_source(self): +        """ +        Relational fields should be able to use queryset-returning methods as their source. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset') + +        class ClassWithQuerysetMethod(object): +            def get_blogposts_queryset(self): +                return BlogPost.objects.all() + +        obj = ClassWithQuerysetMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, [1]) + +    def test_dotted_source(self): +        """ +        Source argument should support dotted.source notation. +        """ +        BlogPost.objects.create(title='blah') +        field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c') + +        class ClassWithQuerysetMethod(object): +            a = { +                'b': { +                    'c': BlogPost.objects.all() +                } +            } + +        obj = ClassWithQuerysetMethod() +        value = field.field_to_native(obj, 'field_name') +        self.assertEqual(value, [1]) diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py new file mode 100644 index 00000000..97ebf23a --- /dev/null +++ b/tests/test_relations_slug.py @@ -0,0 +1,257 @@ +from django.test import TestCase +from rest_framework import serializers +from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): +    sources = serializers.SlugRelatedField(many=True, slug_field='name') + +    class Meta: +        model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    target = serializers.SlugRelatedField(slug_field='name') + +    class Meta: +        model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    target = serializers.SlugRelatedField(slug_field='name', required=False) + +    class Meta: +        model = NullableForeignKeySource + + +# TODO: M2M Tests, FKTests (Non-nullable), One2One +class SlugForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': 'target-1'} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update(self): +        data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-2'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': 'target-1'} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_incorrect_type(self): +        data = {'id': 1, 'name': 'source-1', 'target': 123} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) + +    def test_reverse_foreign_key_update(self): +        data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        # We shouldn't have saved anything to the db yet since save +        # hasn't been called. +        queryset = ForeignKeyTarget.objects.all() +        new_serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +        ] +        self.assertEqual(new_serializer.data, expected) + +        serializer.save() +        self.assertEqual(serializer.data, data) + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, +            {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'id': 4, 'name': 'source-4', 'target': 'target-2'} +        serializer = ForeignKeySourceSerializer(data=data) +        serializer.is_valid() +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': 'target-1'}, +            {'id': 4, 'name': 'source-4', 'target': 'target-2'}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} +        serializer = ForeignKeyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3') + +        # Ensure target 3 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': ['source-2']}, +            {'id': 2, 'name': 'target-2', 'sources': []}, +            {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + + +class SlugNullableForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'id': 4, 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 4, 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 4, 'name': 'source-4', 'target': ''} +        expected_data = {'id': 4, 'name': 'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, expected_data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 4, 'name': 'source-4', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 1, 'name': 'source-1', 'target': ''} +        expected_data = {'id': 1, 'name': 'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, expected_data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': 'source-3', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) diff --git a/tests/test_renderers.py b/tests/test_renderers.py new file mode 100644 index 00000000..f733d6b6 --- /dev/null +++ b/tests/test_renderers.py @@ -0,0 +1,666 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from decimal import Decimal +from django.core.cache import cache +from django.db import models +from django.test import TestCase +from django.utils import unittest +from django.utils.translation import ugettext_lazy as _ +from rest_framework import status, permissions +from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ +    XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer, UnicodeYAMLRenderer +from rest_framework.parsers import YAMLParser, XMLParser +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from collections import MutableMapping +import datetime +import json +import pickle +import re + + +DUMMYSTATUS = status.HTTP_200_OK +DUMMYCONTENT = 'dummycontent' + +RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii') +RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') + + +expected_results = [ +    ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]')  # Generator +] + + +class DummyTestModel(models.Model): +    name = models.CharField(max_length=42, default='') + + +class BasicRendererTests(TestCase): +    def test_expected_results(self): +        for value, renderer_cls, expected in expected_results: +            output = renderer_cls().render(value) +            self.assertEqual(output, expected) + + +class RendererA(BaseRenderer): +    media_type = 'mock/renderera' +    format = "formata" + +    def render(self, data, media_type=None, renderer_context=None): +        return RENDERER_A_SERIALIZER(data) + + +class RendererB(BaseRenderer): +    media_type = 'mock/rendererb' +    format = "formatb" + +    def render(self, data, media_type=None, renderer_context=None): +        return RENDERER_B_SERIALIZER(data) + + +class MockView(APIView): +    renderer_classes = (RendererA, RendererB) + +    def get(self, request, **kwargs): +        response = Response(DUMMYCONTENT, status=DUMMYSTATUS) +        return response + + +class MockGETView(APIView): +    def get(self, request, **kwargs): +        return Response({'foo': ['bar', 'baz']}) + + + +class MockPOSTView(APIView): +    def post(self, request, **kwargs): +        return Response({'foo': request.DATA}) + + +class EmptyGETView(APIView): +    renderer_classes = (JSONRenderer,) + +    def get(self, request, **kwargs): +        return Response(status=status.HTTP_204_NO_CONTENT) + + +class HTMLView(APIView): +    renderer_classes = (BrowsableAPIRenderer, ) + +    def get(self, request, **kwargs): +        return Response('text') + + +class HTMLView1(APIView): +    renderer_classes = (BrowsableAPIRenderer, JSONRenderer) + +    def get(self, request, **kwargs): +        return Response('text') + +urlpatterns = patterns('', +    url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), +    url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), +    url(r'^cache$', MockGETView.as_view()), +    url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), +    url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), +    url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])), +    url(r'^html$', HTMLView.as_view()), +    url(r'^html1$', HTMLView1.as_view()), +    url(r'^empty$', EmptyGETView.as_view()), +    url(r'^api', include('rest_framework.urls', namespace='rest_framework')) +) + + +class POSTDeniedPermission(permissions.BasePermission): +    def has_permission(self, request, view): +        return request.method != 'POST' + + +class POSTDeniedView(APIView): +    renderer_classes = (BrowsableAPIRenderer,) +    permission_classes = (POSTDeniedPermission,) + +    def get(self, request): +        return Response() + +    def post(self, request): +        return Response() + +    def put(self, request): +        return Response() + +    def patch(self, request): +        return Response() + + +class DocumentingRendererTests(TestCase): +    def test_only_permitted_forms_are_displayed(self): +        view = POSTDeniedView.as_view() +        request = APIRequestFactory().get('/') +        response = view(request).render() +        self.assertNotContains(response, '>POST<') +        self.assertContains(response, '>PUT<') +        self.assertContains(response, '>PATCH<') + + +class RendererEndToEndTests(TestCase): +    """ +    End-to-end testing of renderers using an RendererMixin on a generic view. +    """ + +    urls = 'tests.test_renderers' + +    def test_default_renderer_serializes_content(self): +        """If the Accept header is not set the default renderer should serialize the response.""" +        resp = self.client.get('/') +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_head_method_serializes_no_content(self): +        """No response must be included in HEAD requests.""" +        resp = self.client.head('/') +        self.assertEqual(resp.status_code, DUMMYSTATUS) +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, six.b('')) + +    def test_default_renderer_serializes_content_on_accept_any(self): +        """If the Accept header is set to */* the default renderer should serialize the response.""" +        resp = self.client.get('/', HTTP_ACCEPT='*/*') +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_default_case(self): +        """If the Accept header is set the specified renderer should serialize the response. +        (In this case we check that works for the default renderer)""" +        resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_non_default_case(self): +        """If the Accept header is set the specified renderer should serialize the response. +        (In this case we check that works for a non-default renderer)""" +        resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_accept_query(self): +        """The '_accept' query string should behave in the same way as the Accept header.""" +        param = '?%s=%s' % ( +            api_settings.URL_ACCEPT_OVERRIDE, +            RendererB.media_type +        ) +        resp = self.client.get('/' + param) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_unsatisfiable_accept_header_on_request_returns_406_status(self): +        """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" +        resp = self.client.get('/', HTTP_ACCEPT='foo/bar') +        self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) + +    def test_specified_renderer_serializes_content_on_format_query(self): +        """If a 'format' query is specified, the renderer with the matching +        format attribute should serialize the response.""" +        param = '?%s=%s' % ( +            api_settings.URL_FORMAT_OVERRIDE, +            RendererB.format +        ) +        resp = self.client.get('/' + param) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_format_kwargs(self): +        """If a 'format' keyword arg is specified, the renderer with the matching +        format attribute should serialize the response.""" +        resp = self.client.get('/something.formatb') +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): +        """If both a 'format' query and a matching Accept header specified, +        the renderer with the matching format attribute should serialize the response.""" +        param = '?%s=%s' % ( +            api_settings.URL_FORMAT_OVERRIDE, +            RendererB.format +        ) +        resp = self.client.get('/' + param, +                               HTTP_ACCEPT=RendererB.media_type) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_parse_error_renderers_browsable_api(self): +        """Invalid data should still render the browsable API correctly.""" +        resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + +    def test_204_no_content_responses_have_no_content_type_set(self): +        """ +        Regression test for #1196 + +        https://github.com/tomchristie/django-rest-framework/issues/1196 +        """ +        resp = self.client.get('/empty') +        self.assertEqual(resp.get('Content-Type', None), None) +        self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) + +    def test_contains_headers_of_api_response(self): +        """ +        Issue #1437 + +        Test we display the headers of the API response and not those from the +        HTML response +        """ +        resp = self.client.get('/html1') +        self.assertContains(resp, '>GET, HEAD, OPTIONS<') +        self.assertContains(resp, '>application/json<') +        self.assertNotContains(resp, '>text/html; charset=utf-8<') + + +_flat_repr = '{"foo": ["bar", "baz"]}' +_indented_repr = '{\n  "foo": [\n    "bar",\n    "baz"\n  ]\n}' + + +def strip_trailing_whitespace(content): +    """ +    Seems to be some inconsistencies re. trailing whitespace with +    different versions of the json lib. +    """ +    return re.sub(' +\n', '\n', content) + + +class JSONRendererTests(TestCase): +    """ +    Tests specific to the JSON Renderer +    """ + +    def test_render_lazy_strings(self): +        """ +        JSONRenderer should deal with lazy translated strings. +        """ +        ret = JSONRenderer().render(_('test')) +        self.assertEqual(ret, b'"test"') + +    def test_render_queryset_values(self): +        o = DummyTestModel.objects.create(name='dummy') +        qs = DummyTestModel.objects.values('id', 'name') +        ret = JSONRenderer().render(qs) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, [{'id': o.id, 'name': o.name}]) + +    def test_render_queryset_values_list(self): +        o = DummyTestModel.objects.create(name='dummy') +        qs = DummyTestModel.objects.values_list('id', 'name') +        ret = JSONRenderer().render(qs) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, [[o.id, o.name]]) + +    def test_render_dict_abc_obj(self): +        class Dict(MutableMapping): +            def __init__(self): +                self._dict = dict() +            def __getitem__(self, key): +                return self._dict.__getitem__(key) +            def __setitem__(self, key, value): +                return self._dict.__setitem__(key, value) +            def __delitem__(self, key): +                return self._dict.__delitem__(key) +            def __iter__(self): +                return self._dict.__iter__() +            def __len__(self): +                return self._dict.__len__() +            def keys(self): +                return self._dict.keys() + +        x = Dict() +        x['key'] = 'string value' +        x[2] = 3 +        ret = JSONRenderer().render(x) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, {'key': 'string value', '2': 3})     + +    def test_render_obj_with_getitem(self): +        class DictLike(object): +            def __init__(self): +                self._dict = {} +            def set(self, value): +                self._dict = dict(value) +            def __getitem__(self, key): +                return self._dict[key] +             +        x = DictLike() +        x.set({'a': 1, 'b': 'string'}) +        with self.assertRaises(TypeError): +            JSONRenderer().render(x) +         +    def test_without_content_type_args(self): +        """ +        Test basic JSON rendering. +        """ +        obj = {'foo': ['bar', 'baz']} +        renderer = JSONRenderer() +        content = renderer.render(obj, 'application/json') +        # Fix failing test case which depends on version of JSON library. +        self.assertEqual(content.decode('utf-8'), _flat_repr) + +    def test_with_content_type_args(self): +        """ +        Test JSON rendering with additional content type arguments supplied. +        """ +        obj = {'foo': ['bar', 'baz']} +        renderer = JSONRenderer() +        content = renderer.render(obj, 'application/json; indent=2') +        self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) + +    def test_check_ascii(self): +        obj = {'countries': ['United Kingdom', 'France', 'España']} +        renderer = JSONRenderer() +        content = renderer.render(obj, 'application/json') +        self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8')) + + +class UnicodeJSONRendererTests(TestCase): +    """ +    Tests specific for the Unicode JSON Renderer +    """ +    def test_proper_encoding(self): +        obj = {'countries': ['United Kingdom', 'France', 'España']} +        renderer = UnicodeJSONRenderer() +        content = renderer.render(obj, 'application/json') +        self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8')) + + +class JSONPRendererTests(TestCase): +    """ +    Tests specific to the JSONP Renderer +    """ + +    urls = 'tests.test_renderers' + +    def test_without_callback_with_json_renderer(self): +        """ +        Test JSONP rendering with View JSON Renderer. +        """ +        resp = self.client.get('/jsonp/jsonrenderer', +                               HTTP_ACCEPT='application/javascript') +        self.assertEqual(resp.status_code, status.HTTP_200_OK) +        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') +        self.assertEqual(resp.content, +            ('callback(%s);' % _flat_repr).encode('ascii')) + +    def test_without_callback_without_json_renderer(self): +        """ +        Test JSONP rendering without View JSON Renderer. +        """ +        resp = self.client.get('/jsonp/nojsonrenderer', +                               HTTP_ACCEPT='application/javascript') +        self.assertEqual(resp.status_code, status.HTTP_200_OK) +        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') +        self.assertEqual(resp.content, +            ('callback(%s);' % _flat_repr).encode('ascii')) + +    def test_with_callback(self): +        """ +        Test JSONP rendering with callback function name. +        """ +        callback_func = 'myjsonpcallback' +        resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, +                               HTTP_ACCEPT='application/javascript') +        self.assertEqual(resp.status_code, status.HTTP_200_OK) +        self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') +        self.assertEqual(resp.content, +            ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) + + +if yaml: +    _yaml_repr = 'foo: [bar, baz]\n' + +    class YAMLRendererTests(TestCase): +        """ +        Tests specific to the YAML Renderer +        """ + +        def test_render(self): +            """ +            Test basic YAML rendering. +            """ +            obj = {'foo': ['bar', 'baz']} +            renderer = YAMLRenderer() +            content = renderer.render(obj, 'application/yaml') +            self.assertEqual(content, _yaml_repr) + +        def test_render_and_parse(self): +            """ +            Test rendering and then parsing returns the original object. +            IE obj -> render -> parse -> obj. +            """ +            obj = {'foo': ['bar', 'baz']} + +            renderer = YAMLRenderer() +            parser = YAMLParser() + +            content = renderer.render(obj, 'application/yaml') +            data = parser.parse(StringIO(content)) +            self.assertEqual(obj, data) + +        def test_render_decimal(self): +            """ +            Test YAML decimal rendering. +            """ +            renderer = YAMLRenderer() +            content = renderer.render({'field': Decimal('111.2')}, 'application/yaml') +            self.assertYAMLContains(content, "field: '111.2'") + +        def assertYAMLContains(self, content, string): +            self.assertTrue(string in content, '%r not in %r' % (string, content)) + + +    class UnicodeYAMLRendererTests(TestCase): +        """ +        Tests specific for the Unicode YAML Renderer +        """ +        def test_proper_encoding(self): +            obj = {'countries': ['United Kingdom', 'France', 'España']} +            renderer = UnicodeYAMLRenderer() +            content = renderer.render(obj, 'application/yaml') +            self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8')) + + +class XMLRendererTestCase(TestCase): +    """ +    Tests specific to the XML Renderer +    """ + +    _complex_data = { +        "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), +        "name": "name", +        "sub_data_list": [ +            { +                "sub_id": 1, +                "sub_name": "first" +            }, +            { +                "sub_id": 2, +                "sub_name": "second" +            } +        ] +    } + +    def test_render_string(self): +        """ +        Test XML rendering. +        """ +        renderer = XMLRenderer() +        content = renderer.render({'field': 'astring'}, 'application/xml') +        self.assertXMLContains(content, '<field>astring</field>') + +    def test_render_integer(self): +        """ +        Test XML rendering. +        """ +        renderer = XMLRenderer() +        content = renderer.render({'field': 111}, 'application/xml') +        self.assertXMLContains(content, '<field>111</field>') + +    def test_render_datetime(self): +        """ +        Test XML rendering. +        """ +        renderer = XMLRenderer() +        content = renderer.render({ +            'field': datetime.datetime(2011, 12, 25, 12, 45, 00) +        }, 'application/xml') +        self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>') + +    def test_render_float(self): +        """ +        Test XML rendering. +        """ +        renderer = XMLRenderer() +        content = renderer.render({'field': 123.4}, 'application/xml') +        self.assertXMLContains(content, '<field>123.4</field>') + +    def test_render_decimal(self): +        """ +        Test XML rendering. +        """ +        renderer = XMLRenderer() +        content = renderer.render({'field': Decimal('111.2')}, 'application/xml') +        self.assertXMLContains(content, '<field>111.2</field>') + +    def test_render_none(self): +        """ +        Test XML rendering. +        """ +        renderer = XMLRenderer() +        content = renderer.render({'field': None}, 'application/xml') +        self.assertXMLContains(content, '<field></field>') + +    def test_render_complex_data(self): +        """ +        Test XML rendering. +        """ +        renderer = XMLRenderer() +        content = renderer.render(self._complex_data, 'application/xml') +        self.assertXMLContains(content, '<sub_name>first</sub_name>') +        self.assertXMLContains(content, '<sub_name>second</sub_name>') + +    @unittest.skipUnless(etree, 'defusedxml not installed') +    def test_render_and_parse_complex_data(self): +        """ +        Test XML rendering. +        """ +        renderer = XMLRenderer() +        content = StringIO(renderer.render(self._complex_data, 'application/xml')) + +        parser = XMLParser() +        complex_data_out = parser.parse(content) +        error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) +        self.assertEqual(self._complex_data, complex_data_out, error_msg) + +    def assertXMLContains(self, xml, string): +        self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) +        self.assertTrue(xml.endswith('</root>')) +        self.assertTrue(string in xml, '%r not in %r' % (string, xml)) + + +# Tests for caching issue, #346 +class CacheRenderTest(TestCase): +    """ +    Tests specific to caching responses +    """ + +    urls = 'tests.test_renderers' + +    cache_key = 'just_a_cache_key' + +    @classmethod +    def _get_pickling_errors(cls, obj, seen=None): +        """ Return any errors that would be raised if `obj' is pickled +        Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 +        """ +        if seen == None: +            seen = [] +        try: +            state = obj.__getstate__() +        except AttributeError: +            return +        if state == None: +            return +        if isinstance(state, tuple): +            if not isinstance(state[0], dict): +                state = state[1] +            else: +                state = state[0].update(state[1]) +        result = {} +        for i in state: +            try: +                pickle.dumps(state[i], protocol=2) +            except pickle.PicklingError: +                if not state[i] in seen: +                    seen.append(state[i]) +                    result[i] = cls._get_pickling_errors(state[i], seen) +        return result + +    def http_resp(self, http_method, url): +        """ +        Simple wrapper for Client http requests +        Removes the `client' and `request' attributes from as they are +        added by django.test.client.Client and not part of caching +        responses outside of tests. +        """ +        method = getattr(self.client, http_method) +        resp = method(url) +        del resp.client, resp.request +        try: +            del resp.wsgi_request +        except AttributeError: +            pass +        return resp + +    def test_obj_pickling(self): +        """ +        Test that responses are properly pickled +        """ +        resp = self.http_resp('get', '/cache') + +        # Make sure that no pickling errors occurred +        self.assertEqual(self._get_pickling_errors(resp), {}) + +        # Unfortunately LocMem backend doesn't raise PickleErrors but returns +        # None instead. +        cache.set(self.cache_key, resp) +        self.assertTrue(cache.get(self.cache_key) is not None) + +    def test_head_caching(self): +        """ +        Test caching of HEAD requests +        """ +        resp = self.http_resp('head', '/cache') +        cache.set(self.cache_key, resp) + +        cached_resp = cache.get(self.cache_key) +        self.assertIsInstance(cached_resp, Response) + +    def test_get_caching(self): +        """ +        Test caching of GET requests +        """ +        resp = self.http_resp('get', '/cache') +        cache.set(self.cache_key, resp) + +        cached_resp = cache.get(self.cache_key) +        self.assertIsInstance(cached_resp, Response) +        self.assertEqual(cached_resp.content, resp.content) diff --git a/tests/test_request.py b/tests/test_request.py new file mode 100644 index 00000000..0a9355f0 --- /dev/null +++ b/tests/test_request.py @@ -0,0 +1,347 @@ +""" +Tests for content parsing, and form-overloaded content parsing. +""" +from __future__ import unicode_literals +from django.contrib.auth.models import User +from django.contrib.auth import authenticate, login, logout +from django.contrib.sessions.middleware import SessionMiddleware +from django.core.handlers.wsgi import WSGIRequest +from django.test import TestCase +from rest_framework import status +from rest_framework.authentication import SessionAuthentication +from rest_framework.compat import patterns +from rest_framework.parsers import ( +    BaseParser, +    FormParser, +    MultiPartParser, +    JSONParser +) +from rest_framework.request import Request, Empty +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory, APIClient +from rest_framework.views import APIView +from rest_framework.compat import six +from io import BytesIO +import json + + +factory = APIRequestFactory() + + +class PlainTextParser(BaseParser): +    media_type = 'text/plain' + +    def parse(self, stream, media_type=None, parser_context=None): +        """ +        Returns a 2-tuple of `(data, files)`. + +        `data` will simply be a string representing the body of the request. +        `files` will always be `None`. +        """ +        return stream.read() + + +class TestMethodOverloading(TestCase): +    def test_method(self): +        """ +        Request methods should be same as underlying request. +        """ +        request = Request(factory.get('/')) +        self.assertEqual(request.method, 'GET') +        request = Request(factory.post('/')) +        self.assertEqual(request.method, 'POST') + +    def test_overloaded_method(self): +        """ +        POST requests can be overloaded to another method by setting a +        reserved form field +        """ +        request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'})) +        self.assertEqual(request.method, 'DELETE') + +    def test_x_http_method_override_header(self): +        """ +        POST requests can also be overloaded to another method by setting +        the X-HTTP-Method-Override header. +        """ +        request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) +        self.assertEqual(request.method, 'DELETE') + +        request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) +        self.assertEqual(request.method, 'DELETE') + + +class TestContentParsing(TestCase): +    def test_standard_behaviour_determines_no_content_GET(self): +        """ +        Ensure request.DATA returns empty QueryDict for GET request. +        """ +        request = Request(factory.get('/')) +        self.assertEqual(request.DATA, {}) + +    def test_standard_behaviour_determines_no_content_HEAD(self): +        """ +        Ensure request.DATA returns empty QueryDict for HEAD request. +        """ +        request = Request(factory.head('/')) +        self.assertEqual(request.DATA, {}) + +    def test_request_DATA_with_form_content(self): +        """ +        Ensure request.DATA returns content for POST request with form content. +        """ +        data = {'qwerty': 'uiop'} +        request = Request(factory.post('/', data)) +        request.parsers = (FormParser(), MultiPartParser()) +        self.assertEqual(list(request.DATA.items()), list(data.items())) + +    def test_request_DATA_with_text_content(self): +        """ +        Ensure request.DATA returns content for POST request with +        non-form content. +        """ +        content = six.b('qwerty') +        content_type = 'text/plain' +        request = Request(factory.post('/', content, content_type=content_type)) +        request.parsers = (PlainTextParser(),) +        self.assertEqual(request.DATA, content) + +    def test_request_POST_with_form_content(self): +        """ +        Ensure request.POST returns content for POST request with form content. +        """ +        data = {'qwerty': 'uiop'} +        request = Request(factory.post('/', data)) +        request.parsers = (FormParser(), MultiPartParser()) +        self.assertEqual(list(request.POST.items()), list(data.items())) + +    def test_standard_behaviour_determines_form_content_PUT(self): +        """ +        Ensure request.DATA returns content for PUT request with form content. +        """ +        data = {'qwerty': 'uiop'} +        request = Request(factory.put('/', data)) +        request.parsers = (FormParser(), MultiPartParser()) +        self.assertEqual(list(request.DATA.items()), list(data.items())) + +    def test_standard_behaviour_determines_non_form_content_PUT(self): +        """ +        Ensure request.DATA returns content for PUT request with +        non-form content. +        """ +        content = six.b('qwerty') +        content_type = 'text/plain' +        request = Request(factory.put('/', content, content_type=content_type)) +        request.parsers = (PlainTextParser(), ) +        self.assertEqual(request.DATA, content) + +    def test_overloaded_behaviour_allows_content_tunnelling(self): +        """ +        Ensure request.DATA returns content for overloaded POST request. +        """ +        json_data = {'foobar': 'qwerty'} +        content = json.dumps(json_data) +        content_type = 'application/json' +        form_data = { +            api_settings.FORM_CONTENT_OVERRIDE: content, +            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type +        } +        request = Request(factory.post('/', form_data)) +        request.parsers = (JSONParser(), ) +        self.assertEqual(request.DATA, json_data) + +    def test_form_POST_unicode(self): +        """ +        JSON POST via default web interface with unicode data +        """ +        # Note: environ and other variables here have simplified content compared to real Request +        CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D' +        environ = { +            'REQUEST_METHOD': 'POST', +            'CONTENT_TYPE': 'application/x-www-form-urlencoded', +            'CONTENT_LENGTH': len(CONTENT), +            'wsgi.input': BytesIO(CONTENT), +        } +        wsgi_request = WSGIRequest(environ=environ) +        wsgi_request._load_post_and_files() +        parsers = (JSONParser(), FormParser(), MultiPartParser()) +        parser_context = { +            'encoding': 'utf-8', +            'kwargs': {}, +            'args': (), +        } +        request = Request(wsgi_request, parsers=parsers, parser_context=parser_context) +        method = request.method +        self.assertEqual(method, 'POST') +        self.assertEqual(request._content_type, 'application/json') +        self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}') +        self.assertEqual(request._data, Empty) +        self.assertEqual(request._files, Empty) + +    # def test_accessing_post_after_data_form(self): +    #     """ +    #     Ensures request.POST can be accessed after request.DATA in +    #     form request. +    #     """ +    #     data = {'qwerty': 'uiop'} +    #     request = factory.post('/', data=data) +    #     self.assertEqual(request.DATA.items(), data.items()) +    #     self.assertEqual(request.POST.items(), data.items()) + +    # def test_accessing_post_after_data_for_json(self): +    #     """ +    #     Ensures request.POST can be accessed after request.DATA in +    #     json request. +    #     """ +    #     data = {'qwerty': 'uiop'} +    #     content = json.dumps(data) +    #     content_type = 'application/json' +    #     parsers = (JSONParser, ) + +    #     request = factory.post('/', content, content_type=content_type, +    #                            parsers=parsers) +    #     self.assertEqual(request.DATA.items(), data.items()) +    #     self.assertEqual(request.POST.items(), []) + +    # def test_accessing_post_after_data_for_overloaded_json(self): +    #     """ +    #     Ensures request.POST can be accessed after request.DATA in overloaded +    #     json request. +    #     """ +    #     data = {'qwerty': 'uiop'} +    #     content = json.dumps(data) +    #     content_type = 'application/json' +    #     parsers = (JSONParser, ) +    #     form_data = {Request._CONTENT_PARAM: content, +    #                  Request._CONTENTTYPE_PARAM: content_type} + +    #     request = factory.post('/', form_data, parsers=parsers) +    #     self.assertEqual(request.DATA.items(), data.items()) +    #     self.assertEqual(request.POST.items(), form_data.items()) + +    # def test_accessing_data_after_post_form(self): +    #     """ +    #     Ensures request.DATA can be accessed after request.POST in +    #     form request. +    #     """ +    #     data = {'qwerty': 'uiop'} +    #     parsers = (FormParser, MultiPartParser) +    #     request = factory.post('/', data, parsers=parsers) + +    #     self.assertEqual(request.POST.items(), data.items()) +    #     self.assertEqual(request.DATA.items(), data.items()) + +    # def test_accessing_data_after_post_for_json(self): +    #     """ +    #     Ensures request.DATA can be accessed after request.POST in +    #     json request. +    #     """ +    #     data = {'qwerty': 'uiop'} +    #     content = json.dumps(data) +    #     content_type = 'application/json' +    #     parsers = (JSONParser, ) +    #     request = factory.post('/', content, content_type=content_type, +    #                            parsers=parsers) +    #     self.assertEqual(request.POST.items(), []) +    #     self.assertEqual(request.DATA.items(), data.items()) + +    # def test_accessing_data_after_post_for_overloaded_json(self): +    #     """ +    #     Ensures request.DATA can be accessed after request.POST in overloaded +    #     json request +    #     """ +    #     data = {'qwerty': 'uiop'} +    #     content = json.dumps(data) +    #     content_type = 'application/json' +    #     parsers = (JSONParser, ) +    #     form_data = {Request._CONTENT_PARAM: content, +    #                  Request._CONTENTTYPE_PARAM: content_type} + +    #     request = factory.post('/', form_data, parsers=parsers) +    #     self.assertEqual(request.POST.items(), form_data.items()) +    #     self.assertEqual(request.DATA.items(), data.items()) + + +class MockView(APIView): +    authentication_classes = (SessionAuthentication,) + +    def post(self, request): +        if request.POST.get('example') is not None: +            return Response(status=status.HTTP_200_OK) + +        return Response(status=status.INTERNAL_SERVER_ERROR) + +urlpatterns = patterns('', +    (r'^$', MockView.as_view()), +) + + +class TestContentParsingWithAuthentication(TestCase): +    urls = 'tests.test_request' + +    def setUp(self): +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.username = 'john' +        self.email = 'lennon@thebeatles.com' +        self.password = 'password' +        self.user = User.objects.create_user(self.username, self.email, self.password) + +    def test_user_logged_in_authentication_has_POST_when_not_logged_in(self): +        """ +        Ensures request.POST exists after SessionAuthentication when user +        doesn't log in. +        """ +        content = {'example': 'example'} + +        response = self.client.post('/', content) +        self.assertEqual(status.HTTP_200_OK, response.status_code) + +        response = self.csrf_client.post('/', content) +        self.assertEqual(status.HTTP_200_OK, response.status_code) + +    # def test_user_logged_in_authentication_has_post_when_logged_in(self): +    #     """Ensures request.POST exists after UserLoggedInAuthentication when user does log in""" +    #     self.client.login(username='john', password='password') +    #     self.csrf_client.login(username='john', password='password') +    #     content = {'example': 'example'} + +    #     response = self.client.post('/', content) +    #     self.assertEqual(status.OK, response.status_code, "POST data is malformed") + +    #     response = self.csrf_client.post('/', content) +    #     self.assertEqual(status.OK, response.status_code, "POST data is malformed") + + +class TestUserSetter(TestCase): + +    def setUp(self): +        # Pass request object through session middleware so session is +        # available to login and logout functions +        self.request = Request(factory.get('/')) +        SessionMiddleware().process_request(self.request) + +        User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') +        self.user = authenticate(username='ringo', password='yellow') + +    def test_user_can_be_set(self): +        self.request.user = self.user +        self.assertEqual(self.request.user, self.user) + +    def test_user_can_login(self): +        login(self.request, self.user) +        self.assertEqual(self.request.user, self.user) + +    def test_user_can_logout(self): +        self.request.user = self.user +        self.assertFalse(self.request.user.is_anonymous()) +        logout(self.request) +        self.assertTrue(self.request.user.is_anonymous()) + + +class TestAuthSetter(TestCase): + +    def test_auth_can_be_set(self): +        request = Request(factory.get('/')) +        request.auth = 'DUMMY' +        self.assertEqual(request.auth, 'DUMMY') diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 00000000..41c0f49d --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,278 @@ +from __future__ import unicode_literals +from django.test import TestCase +from tests.models import BasicModel, BasicModelSerializer +from rest_framework.compat import patterns, url, include +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework import generics +from rest_framework import routers +from rest_framework import status +from rest_framework.renderers import ( +    BaseRenderer, +    JSONRenderer, +    BrowsableAPIRenderer +) +from rest_framework import viewsets +from rest_framework.settings import api_settings +from rest_framework.compat import six + + +class MockPickleRenderer(BaseRenderer): +    media_type = 'application/pickle' + + +class MockJsonRenderer(BaseRenderer): +    media_type = 'application/json' + + +class MockTextMediaRenderer(BaseRenderer): +    media_type = 'text/html' + +DUMMYSTATUS = status.HTTP_200_OK +DUMMYCONTENT = 'dummycontent' + +RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii') +RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') + + +class RendererA(BaseRenderer): +    media_type = 'mock/renderera' +    format = "formata" + +    def render(self, data, media_type=None, renderer_context=None): +        return RENDERER_A_SERIALIZER(data) + + +class RendererB(BaseRenderer): +    media_type = 'mock/rendererb' +    format = "formatb" + +    def render(self, data, media_type=None, renderer_context=None): +        return RENDERER_B_SERIALIZER(data) + + +class RendererC(RendererB): +    media_type = 'mock/rendererc' +    format = 'formatc' +    charset = "rendererc" + + +class MockView(APIView): +    renderer_classes = (RendererA, RendererB, RendererC) + +    def get(self, request, **kwargs): +        return Response(DUMMYCONTENT, status=DUMMYSTATUS) + + +class MockViewSettingContentType(APIView): +    renderer_classes = (RendererA, RendererB, RendererC) + +    def get(self, request, **kwargs): +        return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview') + + +class HTMLView(APIView): +    renderer_classes = (BrowsableAPIRenderer, ) + +    def get(self, request, **kwargs): +        return Response('text') + + +class HTMLView1(APIView): +    renderer_classes = (BrowsableAPIRenderer, JSONRenderer) + +    def get(self, request, **kwargs): +        return Response('text') + + +class HTMLNewModelViewSet(viewsets.ModelViewSet): +    model = BasicModel + + +class HTMLNewModelView(generics.ListCreateAPIView): +    renderer_classes = (BrowsableAPIRenderer,) +    permission_classes = [] +    serializer_class = BasicModelSerializer +    model = BasicModel + + +new_model_viewset_router = routers.DefaultRouter() +new_model_viewset_router.register(r'', HTMLNewModelViewSet) + + +urlpatterns = patterns('', +    url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])), +    url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), +    url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), +    url(r'^html$', HTMLView.as_view()), +    url(r'^html1$', HTMLView1.as_view()), +    url(r'^html_new_model$', HTMLNewModelView.as_view()), +    url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)), +    url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) +) + + +# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... +class RendererIntegrationTests(TestCase): +    """ +    End-to-end testing of renderers using an ResponseMixin on a generic view. +    """ + +    urls = 'tests.test_response' + +    def test_default_renderer_serializes_content(self): +        """If the Accept header is not set the default renderer should serialize the response.""" +        resp = self.client.get('/') +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_head_method_serializes_no_content(self): +        """No response must be included in HEAD requests.""" +        resp = self.client.head('/') +        self.assertEqual(resp.status_code, DUMMYSTATUS) +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, six.b('')) + +    def test_default_renderer_serializes_content_on_accept_any(self): +        """If the Accept header is set to */* the default renderer should serialize the response.""" +        resp = self.client.get('/', HTTP_ACCEPT='*/*') +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_default_case(self): +        """If the Accept header is set the specified renderer should serialize the response. +        (In this case we check that works for the default renderer)""" +        resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) +        self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_non_default_case(self): +        """If the Accept header is set the specified renderer should serialize the response. +        (In this case we check that works for a non-default renderer)""" +        resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_accept_query(self): +        """The '_accept' query string should behave in the same way as the Accept header.""" +        param = '?%s=%s' % ( +            api_settings.URL_ACCEPT_OVERRIDE, +            RendererB.media_type +        ) +        resp = self.client.get('/' + param) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_format_query(self): +        """If a 'format' query is specified, the renderer with the matching +        format attribute should serialize the response.""" +        resp = self.client.get('/?format=%s' % RendererB.format) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_serializes_content_on_format_kwargs(self): +        """If a 'format' keyword arg is specified, the renderer with the matching +        format attribute should serialize the response.""" +        resp = self.client.get('/something.formatb') +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + +    def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): +        """If both a 'format' query and a matching Accept header specified, +        the renderer with the matching format attribute should serialize the response.""" +        resp = self.client.get('/?format=%s' % RendererB.format, +                               HTTP_ACCEPT=RendererB.media_type) +        self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') +        self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) +        self.assertEqual(resp.status_code, DUMMYSTATUS) + + +class Issue122Tests(TestCase): +    """ +    Tests that covers #122. +    """ +    urls = 'tests.test_response' + +    def test_only_html_renderer(self): +        """ +        Test if no infinite recursion occurs. +        """ +        self.client.get('/html') + +    def test_html_renderer_is_first(self): +        """ +        Test if no infinite recursion occurs. +        """ +        self.client.get('/html1') + + +class Issue467Tests(TestCase): +    """ +    Tests for #467 +    """ + +    urls = 'tests.test_response' + +    def test_form_has_label_and_help_text(self): +        resp = self.client.get('/html_new_model') +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        self.assertContains(resp, 'Text comes here') +        self.assertContains(resp, 'Text description.') + + +class Issue807Tests(TestCase): +    """ +    Covers #807 +    """ + +    urls = 'tests.test_response' + +    def test_does_not_append_charset_by_default(self): +        """ +        Renderers don't include a charset unless set explicitly. +        """ +        headers = {"HTTP_ACCEPT": RendererA.media_type} +        resp = self.client.get('/', **headers) +        expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8') +        self.assertEqual(expected, resp['Content-Type']) + +    def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): +        """ +        If renderer class has charset attribute declared, it gets appended +        to Response's Content-Type +        """ +        headers = {"HTTP_ACCEPT": RendererC.media_type} +        resp = self.client.get('/', **headers) +        expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset) +        self.assertEqual(expected, resp['Content-Type']) + +    def test_content_type_set_explictly_on_response(self): +        """ +        The content type may be set explictly on the response. +        """ +        headers = {"HTTP_ACCEPT": RendererC.media_type} +        resp = self.client.get('/setbyview', **headers) +        self.assertEqual('setbyview', resp['Content-Type']) + +    def test_viewset_label_help_text(self): +        param = '?%s=%s' % ( +            api_settings.URL_ACCEPT_OVERRIDE, +            'text/html' +        ) +        resp = self.client.get('/html_new_model_viewset/' + param) +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        self.assertContains(resp, 'Text comes here') +        self.assertContains(resp, 'Text description.') + +    def test_form_has_label_and_help_text(self): +        resp = self.client.get('/html_new_model') +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        self.assertContains(resp, 'Text comes here') +        self.assertContains(resp, 'Text description.') diff --git a/tests/test_reverse.py b/tests/test_reverse.py new file mode 100644 index 00000000..3d14a28f --- /dev/null +++ b/tests/test_reverse.py @@ -0,0 +1,27 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory + +factory = APIRequestFactory() + + +def null_view(request): +    pass + +urlpatterns = patterns('', +    url(r'^view$', null_view, name='view'), +) + + +class ReverseTests(TestCase): +    """ +    Tests for fully qualified URLs when using `reverse`. +    """ +    urls = 'tests.test_reverse' + +    def test_reversed_urls_are_fully_qualified(self): +        request = factory.get('/view') +        url = reverse('view', request=request) +        self.assertEqual(url, 'http://testserver/view') diff --git a/tests/test_routers.py b/tests/test_routers.py new file mode 100644 index 00000000..084c0e27 --- /dev/null +++ b/tests/test_routers.py @@ -0,0 +1,216 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from django.core.exceptions import ImproperlyConfigured +from rest_framework import serializers, viewsets, permissions +from rest_framework.compat import include, patterns, url +from rest_framework.decorators import link, action +from rest_framework.response import Response +from rest_framework.routers import SimpleRouter, DefaultRouter +from rest_framework.test import APIRequestFactory + +factory = APIRequestFactory() + +urlpatterns = patterns('',) + + +class BasicViewSet(viewsets.ViewSet): +    def list(self, request, *args, **kwargs): +        return Response({'method': 'list'}) + +    @action() +    def action1(self, request, *args, **kwargs): +        return Response({'method': 'action1'}) + +    @action() +    def action2(self, request, *args, **kwargs): +        return Response({'method': 'action2'}) + +    @action(methods=['post', 'delete']) +    def action3(self, request, *args, **kwargs): +        return Response({'method': 'action2'}) + +    @link() +    def link1(self, request, *args, **kwargs): +        return Response({'method': 'link1'}) + +    @link() +    def link2(self, request, *args, **kwargs): +        return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): +    def setUp(self): +        self.router = SimpleRouter() + +    def test_link_and_action_decorator(self): +        routes = self.router.get_routes(BasicViewSet) +        decorator_routes = routes[2:] +        # Make sure all these endpoints exist and none have been clobbered +        for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']): +            route = decorator_routes[i] +            # check url listing +            self.assertEqual(route.url, +                             '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) +            # check method to function mapping +            if endpoint == 'action3': +                methods_map = ['post', 'delete'] +            elif endpoint.startswith('action'): +                methods_map = ['post'] +            else: +                methods_map = ['get'] +            for method in methods_map: +                self.assertEqual(route.mapping[method], endpoint) + + +class RouterTestModel(models.Model): +    uuid = models.CharField(max_length=20) +    text = models.CharField(max_length=200) + + +class TestCustomLookupFields(TestCase): +    """ +    Ensure that custom lookup fields are correctly routed. +    """ +    urls = 'tests.test_routers' + +    def setUp(self): +        class NoteSerializer(serializers.HyperlinkedModelSerializer): +            class Meta: +                model = RouterTestModel +                lookup_field = 'uuid' +                fields = ('url', 'uuid', 'text') + +        class NoteViewSet(viewsets.ModelViewSet): +            queryset = RouterTestModel.objects.all() +            serializer_class = NoteSerializer +            lookup_field = 'uuid' + +        RouterTestModel.objects.create(uuid='123', text='foo bar') + +        self.router = SimpleRouter() +        self.router.register(r'notes', NoteViewSet) + +        from tests import test_routers +        urls = getattr(test_routers, 'urlpatterns') +        urls += patterns('', +            url(r'^', include(self.router.urls)), +        ) + +    def test_custom_lookup_field_route(self): +        detail_route = self.router.urls[-1] +        detail_url_pattern = detail_route.regex.pattern +        self.assertIn('<uuid>', detail_url_pattern) + +    def test_retrieve_lookup_field_list_view(self): +        response = self.client.get('/notes/') +        self.assertEqual(response.data, +            [{ +                "url": "http://testserver/notes/123/", +                "uuid": "123", "text": "foo bar" +            }] +        ) + +    def test_retrieve_lookup_field_detail_view(self): +        response = self.client.get('/notes/123/') +        self.assertEqual(response.data, +            { +                "url": "http://testserver/notes/123/", +                "uuid": "123", "text": "foo bar" +            } +        ) + + +class TestTrailingSlashIncluded(TestCase): +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            model = RouterTestModel + +        self.router = SimpleRouter() +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_urls_have_trailing_slash_by_default(self): +        expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$'] +        for idx in range(len(expected)): +            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestTrailingSlashRemoved(TestCase): +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            model = RouterTestModel + +        self.router = SimpleRouter(trailing_slash=False) +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_urls_can_have_trailing_slash_removed(self): +        expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$'] +        for idx in range(len(expected)): +            self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestNameableRoot(TestCase): +    def setUp(self): +        class NoteViewSet(viewsets.ModelViewSet): +            model = RouterTestModel +        self.router = DefaultRouter() +        self.router.root_view_name = 'nameable-root' +        self.router.register(r'notes', NoteViewSet) +        self.urls = self.router.urls + +    def test_router_has_custom_name(self): +        expected = 'nameable-root' +        self.assertEqual(expected, self.urls[0].name) + + +class TestActionKeywordArgs(TestCase): +    """ +    Ensure keyword arguments passed in the `@action` decorator +    are properly handled.  Refs #940. +    """ + +    def setUp(self): +        class TestViewSet(viewsets.ModelViewSet): +            permission_classes = [] + +            @action(permission_classes=[permissions.AllowAny]) +            def custom(self, request, *args, **kwargs): +                return Response({ +                    'permission_classes': self.permission_classes +                }) + +        self.router = SimpleRouter() +        self.router.register(r'test', TestViewSet, base_name='test') +        self.view = self.router.urls[-1].callback + +    def test_action_kwargs(self): +        request = factory.post('/test/0/custom/') +        response = self.view(request) +        self.assertEqual( +            response.data, +            {'permission_classes': [permissions.AllowAny]} +        ) + + +class TestActionAppliedToExistingRoute(TestCase): +    """ +    Ensure `@action` decorator raises an except when applied +    to an existing route +    """ + +    def test_exception_raised_when_action_applied_to_existing_route(self): +        class TestViewSet(viewsets.ModelViewSet): + +            @action() +            def retrieve(self, request, *args, **kwargs): +                return Response({ +                    'hello': 'world' +                }) + +        self.router = SimpleRouter() +        self.router.register(r'test', TestViewSet, base_name='test') + +        with self.assertRaises(ImproperlyConfigured): +            self.router.urls diff --git a/tests/test_serializer.py b/tests/test_serializer.py new file mode 100644 index 00000000..73eb5c79 --- /dev/null +++ b/tests/test_serializer.py @@ -0,0 +1,1973 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals +from django.db import models +from django.db.models.fields import BLANK_CHOICE_DASH +from django.test import TestCase +from django.utils import unittest +from django.utils.datastructures import MultiValueDict +from django.utils.translation import ugettext_lazy as _ +from rest_framework import serializers, fields, relations +from tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, +    BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, +    ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel, +    ForeignKeySource, ManyToManySource) +from tests.models import BasicModelSerializer +import datetime +import pickle +try: +    import PIL +except: +    PIL = None + + +if PIL is not None: +    class AMOAFModel(RESTFrameworkModel): +        char_field = models.CharField(max_length=1024, blank=True) +        comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) +        decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) +        email_field = models.EmailField(max_length=1024, blank=True) +        file_field = models.FileField(upload_to='test', max_length=1024, blank=True) +        image_field = models.ImageField(upload_to='test', max_length=1024, blank=True) +        slug_field = models.SlugField(max_length=1024, blank=True) +        url_field = models.URLField(max_length=1024, blank=True) + +    class DVOAFModel(RESTFrameworkModel): +        positive_integer_field = models.PositiveIntegerField(blank=True) +        positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) +        email_field = models.EmailField(blank=True) +        file_field = models.FileField(upload_to='test', blank=True) +        image_field = models.ImageField(upload_to='test', blank=True) +        slug_field = models.SlugField(blank=True) +        url_field = models.URLField(blank=True) + + +class SubComment(object): +    def __init__(self, sub_comment): +        self.sub_comment = sub_comment + + +class Comment(object): +    def __init__(self, email, content, created): +        self.email = email +        self.content = content +        self.created = created or datetime.datetime.now() + +    def __eq__(self, other): +        return all([getattr(self, attr) == getattr(other, attr) +                    for attr in ('email', 'content', 'created')]) + +    def get_sub_comment(self): +        sub_comment = SubComment('And Merry Christmas!') +        return sub_comment + + +class CommentSerializer(serializers.Serializer): +    email = serializers.EmailField() +    content = serializers.CharField(max_length=1000) +    created = serializers.DateTimeField() +    sub_comment = serializers.Field(source='get_sub_comment.sub_comment') + +    def restore_object(self, data, instance=None): +        if instance is None: +            return Comment(**data) +        for key, val in data.items(): +            setattr(instance, key, val) +        return instance + + +class NamesSerializer(serializers.Serializer): +    first = serializers.CharField() +    last = serializers.CharField(required=False, default='') +    initials = serializers.CharField(required=False, default='') + + +class PersonIdentifierSerializer(serializers.Serializer): +    ssn = serializers.CharField() +    names = NamesSerializer(source='names', required=False) + + +class BookSerializer(serializers.ModelSerializer): +    isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'}) + +    class Meta: +        model = Book + + +class ActionItemSerializer(serializers.ModelSerializer): + +    class Meta: +        model = ActionItem + +class ActionItemSerializerOptionalFields(serializers.ModelSerializer): +    """ +    Intended to test that fields with `required=False` are excluded from validation. +    """ +    title = serializers.CharField(required=False) + +    class Meta: +        model = ActionItem +        fields = ('title',) + +class ActionItemSerializerCustomRestore(serializers.ModelSerializer): + +    class Meta: +        model = ActionItem + +    def restore_object(self, data, instance=None): +        if instance is None: +            return ActionItem(**data) +        for key, val in data.items(): +            setattr(instance, key, val) +        return instance + + +class PersonSerializer(serializers.ModelSerializer): +    info = serializers.Field(source='info') + +    class Meta: +        model = Person +        fields = ('name', 'age', 'info') +        read_only_fields = ('age',) + + +class NestedSerializer(serializers.Serializer): +    info = serializers.Field() + + +class ModelSerializerWithNestedSerializer(serializers.ModelSerializer): +    nested = NestedSerializer(source='*') + +    class Meta: +        model = Person + + +class NestedSerializerWithRenamedField(serializers.Serializer): +    renamed_info = serializers.Field(source='info') + + +class ModelSerializerWithNestedSerializerWithRenamedField(serializers.ModelSerializer): +    nested = NestedSerializerWithRenamedField(source='*') + +    class Meta: +        model = Person + + +class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): +    """ +    Testing for #652. +    """ +    info = serializers.Field(source='info') + +    class Meta: +        model = Person +        fields = ('name', 'age', 'info') +        read_only_fields = ('age', 'info') + + +class AlbumsSerializer(serializers.ModelSerializer): + +    class Meta: +        model = Album +        fields = ['title', 'ref']  # lists are also valid options + + +class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): +    class Meta: +        model = HasPositiveIntegerAsChoice +        fields = ['some_integer'] + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = ForeignKeySource + + +class HyperlinkedForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = ForeignKeySource + + +class BasicTests(TestCase): +    def setUp(self): +        self.comment = Comment( +            'tom@example.com', +            'Happy new year!', +            datetime.datetime(2012, 1, 1) +        ) +        self.actionitem = ActionItem(title='Some to do item',) +        self.data = { +            'email': 'tom@example.com', +            'content': 'Happy new year!', +            'created': datetime.datetime(2012, 1, 1), +            'sub_comment': 'This wont change' +        } +        self.expected = { +            'email': 'tom@example.com', +            'content': 'Happy new year!', +            'created': datetime.datetime(2012, 1, 1), +            'sub_comment': 'And Merry Christmas!' +        } +        self.person_data = {'name': 'dwight', 'age': 35} +        self.person = Person(**self.person_data) +        self.person.save() + +    def test_empty(self): +        serializer = CommentSerializer() +        expected = { +            'email': '', +            'content': '', +            'created': None +        } +        self.assertEqual(serializer.data, expected) + +    def test_retrieve(self): +        serializer = CommentSerializer(self.comment) +        self.assertEqual(serializer.data, self.expected) + +    def test_create(self): +        serializer = CommentSerializer(data=self.data) +        expected = self.comment +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected) +        self.assertFalse(serializer.object is expected) +        self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') + +    def test_create_nested(self): +        """Test a serializer with nested data.""" +        names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'} +        data = {'ssn': '1234567890', 'names': names} +        serializer = PersonIdentifierSerializer(data=data) + +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, data) +        self.assertFalse(serializer.object is data) +        self.assertEqual(serializer.data['names'], names) + +    def test_create_partial_nested(self): +        """Test a serializer with nested data which has missing fields.""" +        names = {'first': 'John'} +        data = {'ssn': '1234567890', 'names': names} +        serializer = PersonIdentifierSerializer(data=data) + +        expected_names = {'first': 'John', 'last': '', 'initials': ''} +        data['names'] = expected_names + +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, data) +        self.assertFalse(serializer.object is expected_names) +        self.assertEqual(serializer.data['names'], expected_names) + +    def test_null_nested(self): +        """Test a serializer with a nonexistent nested field""" +        data = {'ssn': '1234567890'} +        serializer = PersonIdentifierSerializer(data=data) + +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, data) +        self.assertFalse(serializer.object is data) +        expected = {'ssn': '1234567890', 'names': None} +        self.assertEqual(serializer.data, expected) + +    def test_update(self): +        serializer = CommentSerializer(self.comment, data=self.data) +        expected = self.comment +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected) +        self.assertTrue(serializer.object is expected) +        self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') + +    def test_partial_update(self): +        msg = 'Merry New Year!' +        partial_data = {'content': msg} +        serializer = CommentSerializer(self.comment, data=partial_data) +        self.assertEqual(serializer.is_valid(), False) +        serializer = CommentSerializer(self.comment, data=partial_data, partial=True) +        expected = self.comment +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected) +        self.assertTrue(serializer.object is expected) +        self.assertEqual(serializer.data['content'], msg) + +    def test_model_fields_as_expected(self): +        """ +        Make sure that the fields returned are the same as defined +        in the Meta data +        """ +        serializer = PersonSerializer(self.person) +        self.assertEqual(set(serializer.data.keys()), +                          set(['name', 'age', 'info'])) + +    def test_field_with_dictionary(self): +        """ +        Make sure that dictionaries from fields are left intact +        """ +        serializer = PersonSerializer(self.person) +        expected = self.person_data +        self.assertEqual(serializer.data['info'], expected) + +    def test_read_only_fields(self): +        """ +        Attempting to update fields set as read_only should have no effect. +        """ +        serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(serializer.errors, {}) +        # Assert age is unchanged (35) +        self.assertEqual(instance.age, self.person_data['age']) + +    def test_invalid_read_only_fields(self): +        """ +        Regression test for #652. +        """ +        self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) + +    def test_serializer_data_is_cleared_on_save(self): +        """ +        Check _data attribute is cleared on `save()` + +        Regression test for #1116 +            — id field is not populated if `data` is accessed prior to `save()` +        """ +        serializer = ActionItemSerializer(self.actionitem) +        self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.') +        serializer.save() +        self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.') + +    def test_fields_marked_as_not_required_are_excluded_from_validation(self): +        """ +        Check that fields with `required=False` are included in list of exclusions. +        """ +        serializer = ActionItemSerializerOptionalFields(self.actionitem) +        exclusions = serializer.get_validation_exclusions() +        self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded') + + +class DictStyleSerializer(serializers.Serializer): +    """ +    Note that we don't have any `restore_object` method, so the default +    case of simply returning a dict will apply. +    """ +    email = serializers.EmailField() + + +class DictStyleSerializerTests(TestCase): +    def test_dict_style_deserialize(self): +        """ +        Ensure serializers can deserialize into a dict. +        """ +        data = {'email': 'foo@example.com'} +        serializer = DictStyleSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, data) + +    def test_dict_style_serialize(self): +        """ +        Ensure serializers can serialize dict objects. +        """ +        data = {'email': 'foo@example.com'} +        serializer = DictStyleSerializer(data) +        self.assertEqual(serializer.data, data) + + +class ValidationTests(TestCase): +    def setUp(self): +        self.comment = Comment( +            'tom@example.com', +            'Happy new year!', +            datetime.datetime(2012, 1, 1) +        ) +        self.data = { +            'email': 'tom@example.com', +            'content': 'x' * 1001, +            'created': datetime.datetime(2012, 1, 1) +        } +        self.actionitem = ActionItem(title='Some to do item',) + +    def test_create(self): +        serializer = CommentSerializer(data=self.data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) + +    def test_update(self): +        serializer = CommentSerializer(self.comment, data=self.data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) + +    def test_update_missing_field(self): +        data = { +            'content': 'xxx', +            'created': datetime.datetime(2012, 1, 1) +        } +        serializer = CommentSerializer(self.comment, data=data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, {'email': ['This field is required.']}) + +    def test_missing_bool_with_default(self): +        """Make sure that a boolean value with a 'False' value is not +        mistaken for not having a default.""" +        data = { +            'title': 'Some action item', +            #No 'done' value. +        } +        serializer = ActionItemSerializer(self.actionitem, data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.errors, {}) + +    def test_cross_field_validation(self): + +        class CommentSerializerWithCrossFieldValidator(CommentSerializer): + +            def validate(self, attrs): +                if attrs["email"] not in attrs["content"]: +                    raise serializers.ValidationError("Email address not in content") +                return attrs + +        data = { +            'email': 'tom@example.com', +            'content': 'A comment from tom@example.com', +            'created': datetime.datetime(2012, 1, 1) +        } + +        serializer = CommentSerializerWithCrossFieldValidator(data=data) +        self.assertTrue(serializer.is_valid()) + +        data['content'] = 'A comment from foo@bar.com' + +        serializer = CommentSerializerWithCrossFieldValidator(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']}) + +    def test_null_is_true_fields(self): +        """ +        Omitting a value for null-field should validate. +        """ +        serializer = PersonSerializer(data={'name': 'marko'}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.errors, {}) + +    def test_modelserializer_max_length_exceeded(self): +        data = { +            'title': 'x' * 201, +        } +        serializer = ActionItemSerializer(data=data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) + +    def test_modelserializer_max_length_exceeded_with_custom_restore(self): +        """ +        When overriding ModelSerializer.restore_object, validation tests should still apply. +        Regression test for #623. + +        https://github.com/tomchristie/django-rest-framework/pull/623 +        """ +        data = { +            'title': 'x' * 201, +        } +        serializer = ActionItemSerializerCustomRestore(data=data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) + +    def test_default_modelfield_max_length_exceeded(self): +        data = { +            'title': 'Testing "info" field...', +            'info': 'x' * 13, +        } +        serializer = ActionItemSerializer(data=data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']}) + +    def test_datetime_validation_failure(self): +        """ +        Test DateTimeField validation errors on non-str values. +        Regression test for #669. + +        https://github.com/tomchristie/django-rest-framework/issues/669 +        """ +        data = self.data +        data['created'] = 0 + +        serializer = CommentSerializer(data=data) +        self.assertEqual(serializer.is_valid(), False) + +        self.assertIn('created', serializer.errors) + +    def test_missing_model_field_exception_msg(self): +        """ +        Assert that a meaningful exception message is outputted when the model +        field is missing (e.g. when mistyping ``model``). +        """ +        class BrokenModelSerializer(serializers.ModelSerializer): +            class Meta: +                fields = ['some_field'] + +        try: +            BrokenModelSerializer() +        except AssertionError as e: +            self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option") +        except: +            self.fail('Wrong exception type thrown.') + +    def test_writable_star_source_on_nested_serializer(self): +        """ +        Assert that a nested serializer instantiated with source='*' correctly +        expands the data into the outer serializer. +        """ +        serializer = ModelSerializerWithNestedSerializer(data={ +            'name': 'marko', +            'nested': {'info': 'hi'}}, +        ) +        self.assertEqual(serializer.is_valid(), True) + +    def test_writable_star_source_on_nested_serializer_with_parent_object(self): +        class TitleSerializer(serializers.Serializer): +            title = serializers.WritableField(source='title') + +        class AlbumSerializer(serializers.ModelSerializer): +            nested = TitleSerializer(source='*') + +            class Meta: +                model = Album +                fields = ('nested',) + +        class PhotoSerializer(serializers.ModelSerializer): +            album = AlbumSerializer(source='album') + +            class Meta: +                model = Photo +                fields = ('album', ) + +        photo = Photo(album=Album()) + +        data = {'album': {'nested': {'title': 'test'}}} + +        serializer = PhotoSerializer(photo, data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) + +    def test_writable_star_source_with_inner_source_fields(self): +        """ +        Tests that a serializer with source="*" correctly expands the +        it's fields into the outer serializer even if they have their +        own 'source' parameters. +        """ + +        serializer = ModelSerializerWithNestedSerializerWithRenamedField(data={ +            'name': 'marko', +            'nested': {'renamed_info': 'hi'}}, +        ) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.errors, {}) + + +class CustomValidationTests(TestCase): +    class CommentSerializerWithFieldValidator(CommentSerializer): + +        def validate_email(self, attrs, source): +            attrs[source] +            return attrs + +        def validate_content(self, attrs, source): +            value = attrs[source] +            if "test" not in value: +                raise serializers.ValidationError("Test not in value") +            return attrs + +    def test_field_validation(self): +        data = { +            'email': 'tom@example.com', +            'content': 'A test comment', +            'created': datetime.datetime(2012, 1, 1) +        } + +        serializer = self.CommentSerializerWithFieldValidator(data=data) +        self.assertTrue(serializer.is_valid()) + +        data['content'] = 'This should not validate' + +        serializer = self.CommentSerializerWithFieldValidator(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'content': ['Test not in value']}) + +    def test_missing_data(self): +        """ +        Make sure that validate_content isn't called if the field is missing +        """ +        incomplete_data = { +            'email': 'tom@example.com', +            'created': datetime.datetime(2012, 1, 1) +        } +        serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'content': ['This field is required.']}) + +    def test_wrong_data(self): +        """ +        Make sure that validate_content isn't called if the field input is wrong +        """ +        wrong_data = { +            'email': 'not an email', +            'content': 'A test comment', +            'created': datetime.datetime(2012, 1, 1) +        } +        serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']}) + +    def test_partial_update(self): +        """ +        Make sure that validate_email isn't called when partial=True and email +        isn't found in data. +        """ +        initial_data = { +            'email': 'tom@example.com', +            'content': 'A test comment', +            'created': datetime.datetime(2012, 1, 1) +        } + +        serializer = self.CommentSerializerWithFieldValidator(data=initial_data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.object + +        new_content = 'An *updated* test comment' +        partial_data = { +            'content': new_content +        } + +        serializer = self.CommentSerializerWithFieldValidator(instance=instance, +                                                              data=partial_data, +                                                              partial=True) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.object +        self.assertEqual(instance.content, new_content) + + +class PositiveIntegerAsChoiceTests(TestCase): +    def test_positive_integer_in_json_is_correctly_parsed(self): +        data = {'some_integer': 1} +        serializer = PositiveIntegerAsChoiceSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) + + +class ModelValidationTests(TestCase): +    def test_validate_unique(self): +        """ +        Just check if serializers.ModelSerializer handles unique checks via .full_clean() +        """ +        serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'}) +        serializer.is_valid() +        serializer.save() +        second_serializer = AlbumsSerializer(data={'title': 'a'}) +        self.assertFalse(second_serializer.is_valid()) +        self.assertEqual(second_serializer.errors,  {'title': ['Album with this Title already exists.'],}) +        third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}]) +        self.assertFalse(third_serializer.is_valid()) +        self.assertEqual(third_serializer.errors,  [{'ref': ['Album with this Ref already exists.']}, {}]) + +    def test_foreign_key_is_null_with_partial(self): +        """ +        Test ModelSerializer validation with partial=True + +        Specifically test that a null foreign key does not pass validation +        """ +        album = Album(title='test') +        album.save() + +        class PhotoSerializer(serializers.ModelSerializer): +            class Meta: +                model = Photo + +        photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) +        self.assertTrue(photo_serializer.is_valid()) +        photo = photo_serializer.save() + +        # Updating only the album (foreign key) +        photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True) +        self.assertFalse(photo_serializer.is_valid()) +        self.assertTrue('album' in photo_serializer.errors) +        self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required']) + +    def test_foreign_key_with_partial(self): +        """ +        Test ModelSerializer validation with partial=True + +        Specifically test foreign key validation. +        """ + +        album = Album(title='test') +        album.save() + +        class PhotoSerializer(serializers.ModelSerializer): +            class Meta: +                model = Photo + +        photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) +        self.assertTrue(photo_serializer.is_valid()) +        photo = photo_serializer.save() + +        # Updating only the album (foreign key) +        photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True) +        self.assertTrue(photo_serializer.is_valid()) +        self.assertTrue(photo_serializer.save()) + +        # Updating only the description +        photo_serializer = PhotoSerializer(instance=photo, +                                           data={'description': 'new'}, +                                           partial=True) + +        self.assertTrue(photo_serializer.is_valid()) +        self.assertTrue(photo_serializer.save()) + + +class RegexValidationTest(TestCase): +    def test_create_failed(self): +        serializer = BookSerializer(data={'isbn': '1234567890'}) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) + +        serializer = BookSerializer(data={'isbn': '12345678901234'}) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) + +        serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) + +    def test_create_success(self): +        serializer = BookSerializer(data={'isbn': '1234567890123'}) +        self.assertTrue(serializer.is_valid()) + + +class MetadataTests(TestCase): +    def test_empty(self): +        serializer = CommentSerializer() +        expected = { +            'email': serializers.CharField, +            'content': serializers.CharField, +            'created': serializers.DateTimeField +        } +        for field_name, field in expected.items(): +            self.assertTrue(isinstance(serializer.data.fields[field_name], field)) + + +class ManyToManyTests(TestCase): +    def setUp(self): +        class ManyToManySerializer(serializers.ModelSerializer): +            class Meta: +                model = ManyToManyModel + +        self.serializer_class = ManyToManySerializer + +        # An anchor instance to use for the relationship +        self.anchor = Anchor() +        self.anchor.save() + +        # A model instance with a many to many relationship to the anchor +        self.instance = ManyToManyModel() +        self.instance.save() +        self.instance.rel.add(self.anchor) + +        # A serialized representation of the model instance +        self.data = {'id': 1, 'rel': [self.anchor.id]} + +    def test_retrieve(self): +        """ +        Serialize an instance of a model with a ManyToMany relationship. +        """ +        serializer = self.serializer_class(instance=self.instance) +        expected = self.data +        self.assertEqual(serializer.data, expected) + +    def test_create(self): +        """ +        Create an instance of a model with a ManyToMany relationship. +        """ +        data = {'rel': [self.anchor.id]} +        serializer = self.serializer_class(data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(ManyToManyModel.objects.all()), 2) +        self.assertEqual(instance.pk, 2) +        self.assertEqual(list(instance.rel.all()), [self.anchor]) + +    def test_update(self): +        """ +        Update an instance of a model with a ManyToMany relationship. +        """ +        new_anchor = Anchor() +        new_anchor.save() +        data = {'rel': [self.anchor.id, new_anchor.id]} +        serializer = self.serializer_class(self.instance, data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(ManyToManyModel.objects.all()), 1) +        self.assertEqual(instance.pk, 1) +        self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor]) + +    def test_create_empty_relationship(self): +        """ +        Create an instance of a model with a ManyToMany relationship, +        containing no items. +        """ +        data = {'rel': []} +        serializer = self.serializer_class(data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(ManyToManyModel.objects.all()), 2) +        self.assertEqual(instance.pk, 2) +        self.assertEqual(list(instance.rel.all()), []) + +    def test_update_empty_relationship(self): +        """ +        Update an instance of a model with a ManyToMany relationship, +        containing no items. +        """ +        new_anchor = Anchor() +        new_anchor.save() +        data = {'rel': []} +        serializer = self.serializer_class(self.instance, data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(ManyToManyModel.objects.all()), 1) +        self.assertEqual(instance.pk, 1) +        self.assertEqual(list(instance.rel.all()), []) + +    def test_create_empty_relationship_flat_data(self): +        """ +        Create an instance of a model with a ManyToMany relationship, +        containing no items, using a representation that does not support +        lists (eg form data). +        """ +        data = MultiValueDict() +        data.setlist('rel', ['']) +        serializer = self.serializer_class(data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(ManyToManyModel.objects.all()), 2) +        self.assertEqual(instance.pk, 2) +        self.assertEqual(list(instance.rel.all()), []) + + +class ReadOnlyManyToManyTests(TestCase): +    def setUp(self): +        class ReadOnlyManyToManySerializer(serializers.ModelSerializer): +            rel = serializers.RelatedField(many=True, read_only=True) + +            class Meta: +                model = ReadOnlyManyToManyModel + +        self.serializer_class = ReadOnlyManyToManySerializer + +        # An anchor instance to use for the relationship +        self.anchor = Anchor() +        self.anchor.save() + +        # A model instance with a many to many relationship to the anchor +        self.instance = ReadOnlyManyToManyModel() +        self.instance.save() +        self.instance.rel.add(self.anchor) + +        # A serialized representation of the model instance +        self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'} + +    def test_update(self): +        """ +        Attempt to update an instance of a model with a ManyToMany +        relationship.  Not updated due to read_only=True +        """ +        new_anchor = Anchor() +        new_anchor.save() +        data = {'rel': [self.anchor.id, new_anchor.id]} +        serializer = self.serializer_class(self.instance, data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) +        self.assertEqual(instance.pk, 1) +        # rel is still as original (1 entry) +        self.assertEqual(list(instance.rel.all()), [self.anchor]) + +    def test_update_without_relationship(self): +        """ +        Attempt to update an instance of a model where many to ManyToMany +        relationship is not supplied.  Not updated due to read_only=True +        """ +        new_anchor = Anchor() +        new_anchor.save() +        data = {} +        serializer = self.serializer_class(self.instance, data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) +        self.assertEqual(instance.pk, 1) +        # rel is still as original (1 entry) +        self.assertEqual(list(instance.rel.all()), [self.anchor]) + + +class DefaultValueTests(TestCase): +    def setUp(self): +        class DefaultValueSerializer(serializers.ModelSerializer): +            class Meta: +                model = DefaultValueModel + +        self.serializer_class = DefaultValueSerializer +        self.objects = DefaultValueModel.objects + +    def test_create_using_default(self): +        data = {} +        serializer = self.serializer_class(data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(self.objects.all()), 1) +        self.assertEqual(instance.pk, 1) +        self.assertEqual(instance.text, 'foobar') + +    def test_create_overriding_default(self): +        data = {'text': 'overridden'} +        serializer = self.serializer_class(data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(self.objects.all()), 1) +        self.assertEqual(instance.pk, 1) +        self.assertEqual(instance.text, 'overridden') + +    def test_partial_update_default(self): +        """ Regression test for issue #532 """ +        data = {'text': 'overridden'} +        serializer = self.serializer_class(data=data, partial=True) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() + +        data = {'extra': 'extra_value'} +        serializer = self.serializer_class(instance=instance, data=data, partial=True) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() + +        self.assertEqual(instance.extra, 'extra_value') +        self.assertEqual(instance.text, 'overridden') + + +class WritableFieldDefaultValueTests(TestCase): + +    def setUp(self): +        self.expected = {'default': 'value'} +        self.create_field = fields.WritableField + +    def test_get_default_value_with_noncallable(self): +        field = self.create_field(default=self.expected) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_with_callable(self): +        field = self.create_field(default=lambda : self.expected) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_when_not_required(self): +        field = self.create_field(default=self.expected, required=False) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_returns_None(self): +        field = self.create_field() +        got = field.get_default_value() +        self.assertIsNone(got) + +    def test_get_default_value_returns_non_True_values(self): +        values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause +        for expected in values: +            field = self.create_field(default=expected) +            got = field.get_default_value() +            self.assertEqual(got, expected) + + +class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): + +    def setUp(self): +        self.expected = {'foo': 'bar'} +        self.create_field = relations.RelatedField + +    def test_get_default_value_returns_empty_list(self): +        field = self.create_field(many=True) +        got = field.get_default_value() +        self.assertListEqual(got, []) + +    def test_get_default_value_returns_expected(self): +        expected = [1, 2, 3] +        field = self.create_field(many=True, default=expected) +        got = field.get_default_value() +        self.assertListEqual(got, expected) + + +class CallableDefaultValueTests(TestCase): +    def setUp(self): +        class CallableDefaultValueSerializer(serializers.ModelSerializer): +            class Meta: +                model = CallableDefaultValueModel + +        self.serializer_class = CallableDefaultValueSerializer +        self.objects = CallableDefaultValueModel.objects + +    def test_create_using_default(self): +        data = {} +        serializer = self.serializer_class(data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(self.objects.all()), 1) +        self.assertEqual(instance.pk, 1) +        self.assertEqual(instance.text, 'foobar') + +    def test_create_overriding_default(self): +        data = {'text': 'overridden'} +        serializer = self.serializer_class(data=data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.save() +        self.assertEqual(len(self.objects.all()), 1) +        self.assertEqual(instance.pk, 1) +        self.assertEqual(instance.text, 'overridden') + + +class ManyRelatedTests(TestCase): +    def test_reverse_relations(self): +        post = BlogPost.objects.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostCommentSerializer(serializers.Serializer): +            text = serializers.CharField() + +        class BlogPostSerializer(serializers.Serializer): +            title = serializers.CharField() +            comments = BlogPostCommentSerializer(source='blogpostcomment_set') + +        serializer = BlogPostSerializer(instance=post) +        expected = { +            'title': 'Test blog post', +            'comments': [ +                {'text': 'I hate this blog post'}, +                {'text': 'I love this blog post'} +            ] +        } + +        self.assertEqual(serializer.data, expected) + +    def test_include_reverse_relations(self): +        post = BlogPost.objects.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlogPost +                fields = ('id', 'title', 'blogpostcomment_set') + +        serializer = BlogPostSerializer(instance=post) +        expected = { +            'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] +        } +        self.assertEqual(serializer.data, expected) + +    def test_depth_include_reverse_relations(self): +        post = BlogPost.objects.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlogPost +                fields = ('id', 'title', 'blogpostcomment_set') +                depth = 1 + +        serializer = BlogPostSerializer(instance=post) +        expected = { +            'id': 1, 'title': 'Test blog post', +            'blogpostcomment_set': [ +                {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, +                {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} +            ] +        } +        self.assertEqual(serializer.data, expected) + +    def test_callable_source(self): +        post = BlogPost.objects.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostCommentSerializer(serializers.Serializer): +            text = serializers.CharField() + +        class BlogPostSerializer(serializers.Serializer): +            title = serializers.CharField() +            first_comment = BlogPostCommentSerializer(source='get_first_comment') + +        serializer = BlogPostSerializer(post) + +        expected = { +            'title': 'Test blog post', +            'first_comment': {'text': 'I love this blog post'} +        } +        self.assertEqual(serializer.data, expected) + + +class RelatedTraversalTest(TestCase): +    def test_nested_traversal(self): +        """ +        Source argument should support dotted.source notation. +        """ +        user = Person.objects.create(name="django") +        post = BlogPost.objects.create(title="Test blog post", writer=user) +        post.blogpostcomment_set.create(text="I love this blog post") + +        class PersonSerializer(serializers.ModelSerializer): +            class Meta: +                model = Person +                fields = ("name", "age") + +        class BlogPostCommentSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlogPostComment +                fields = ("text", "post_owner") + +            text = serializers.CharField() +            post_owner = PersonSerializer(source='blog_post.writer') + +        class BlogPostSerializer(serializers.Serializer): +            title = serializers.CharField() +            comments = BlogPostCommentSerializer(source='blogpostcomment_set') + +        serializer = BlogPostSerializer(instance=post) + +        expected = { +            'title': 'Test blog post', +            'comments': [{ +                'text': 'I love this blog post', +                'post_owner': { +                    "name": "django", +                    "age": None +                } +            }] +        } + +        self.assertEqual(serializer.data, expected) + +    def test_nested_traversal_with_none(self): +        """ +        If a component of the dotted.source is None, return None for the field. +        """ +        from tests.models import NullableForeignKeySource +        instance = NullableForeignKeySource.objects.create(name='Source with null FK') + +        class NullableSourceSerializer(serializers.Serializer): +            target_name = serializers.Field(source='target.name') + +        serializer = NullableSourceSerializer(instance=instance) + +        expected = { +            'target_name': None, +        } + +        self.assertEqual(serializer.data, expected) + + +class SerializerMethodFieldTests(TestCase): +    def setUp(self): + +        class BoopSerializer(serializers.Serializer): +            beep = serializers.SerializerMethodField('get_beep') +            boop = serializers.Field() +            boop_count = serializers.SerializerMethodField('get_boop_count') + +            def get_beep(self, obj): +                return 'hello!' + +            def get_boop_count(self, obj): +                return len(obj.boop) + +        self.serializer_class = BoopSerializer + +    def test_serializer_method_field(self): + +        class MyModel(object): +            boop = ['a', 'b', 'c'] + +        source_data = MyModel() + +        serializer = self.serializer_class(source_data) + +        expected = { +            'beep': 'hello!', +            'boop': ['a', 'b', 'c'], +            'boop_count': 3, +        } + +        self.assertEqual(serializer.data, expected) + + +# Test for issue #324 +class BlankFieldTests(TestCase): +    def setUp(self): + +        class BlankFieldModelSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlankFieldModel + +        class BlankFieldSerializer(serializers.Serializer): +            title = serializers.CharField(required=False) + +        class NotBlankFieldModelSerializer(serializers.ModelSerializer): +            class Meta: +                model = BasicModel + +        class NotBlankFieldSerializer(serializers.Serializer): +            title = serializers.CharField() + +        self.model_serializer_class = BlankFieldModelSerializer +        self.serializer_class = BlankFieldSerializer +        self.not_blank_model_serializer_class = NotBlankFieldModelSerializer +        self.not_blank_serializer_class = NotBlankFieldSerializer +        self.data = {'title': ''} + +    def test_create_blank_field(self): +        serializer = self.serializer_class(data=self.data) +        self.assertEqual(serializer.is_valid(), True) + +    def test_create_model_blank_field(self): +        serializer = self.model_serializer_class(data=self.data) +        self.assertEqual(serializer.is_valid(), True) + +    def test_create_model_null_field(self): +        serializer = self.model_serializer_class(data={'title': None}) +        self.assertEqual(serializer.is_valid(), True) + +    def test_create_not_blank_field(self): +        """ +        Test to ensure blank data in a field not marked as blank=True +        is considered invalid in a non-model serializer +        """ +        serializer = self.not_blank_serializer_class(data=self.data) +        self.assertEqual(serializer.is_valid(), False) + +    def test_create_model_not_blank_field(self): +        """ +        Test to ensure blank data in a field not marked as blank=True +        is considered invalid in a model serializer +        """ +        serializer = self.not_blank_model_serializer_class(data=self.data) +        self.assertEqual(serializer.is_valid(), False) + +    def test_create_model_empty_field(self): +        serializer = self.model_serializer_class(data={}) +        self.assertEqual(serializer.is_valid(), True) + + +#test for issue #460 +class SerializerPickleTests(TestCase): +    """ +    Test pickleability of the output of Serializers +    """ +    def test_pickle_simple_model_serializer_data(self): +        """ +        Test simple serializer +        """ +        pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data) + +    def test_pickle_inner_serializer(self): +        """ +        Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will +        have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle. +        See DictWithMetadata.__getstate__ +        """ +        class InnerPersonSerializer(serializers.ModelSerializer): +            class Meta: +                model = Person +                fields = ('name', 'age') +        pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0) + +    def test_getstate_method_should_not_return_none(self): +        """ +        Regression test for #645. +        """ +        data = serializers.DictWithMetadata({1: 1}) +        self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1})) + +    def test_serializer_data_is_pickleable(self): +        """ +        Another regression test for #645. +        """ +        data = serializers.SortedDictWithMetadata({1: 1}) +        repr(pickle.loads(pickle.dumps(data, 0))) + + +# test for issue #725 +class SeveralChoicesModel(models.Model): +    color = models.CharField( +        max_length=10, +        choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], +        blank=False +    ) +    drink = models.CharField( +        max_length=10, +        choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], +        blank=False, +        default='beer' +    ) +    os = models.CharField( +        max_length=10, +        choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], +        blank=True +    ) +    music_genre = models.CharField( +        max_length=10, +        choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], +        blank=True, +        default='metal' +    ) + + +class SerializerChoiceFields(TestCase): + +    def setUp(self): +        super(SerializerChoiceFields, self).setUp() + +        class SeveralChoicesSerializer(serializers.ModelSerializer): +            class Meta: +                model = SeveralChoicesModel +                fields = ('color', 'drink', 'os', 'music_genre') + +        self.several_choices_serializer = SeveralChoicesSerializer + +    def test_choices_blank_false_not_default(self): +        serializer = self.several_choices_serializer() +        self.assertEqual( +            serializer.fields['color'].choices, +            [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')] +        ) + +    def test_choices_blank_false_with_default(self): +        serializer = self.several_choices_serializer() +        self.assertEqual( +            serializer.fields['drink'].choices, +            [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')] +        ) + +    def test_choices_blank_true_not_default(self): +        serializer = self.several_choices_serializer() +        self.assertEqual( +            serializer.fields['os'].choices, +            BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')] +        ) + +    def test_choices_blank_true_with_default(self): +        serializer = self.several_choices_serializer() +        self.assertEqual( +            serializer.fields['music_genre'].choices, +            BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')] +        ) + + +# Regression tests for #675 +class Ticket(models.Model): +    assigned = models.ForeignKey( +        Person, related_name='assigned_tickets') +    reviewer = models.ForeignKey( +        Person, blank=True, null=True, related_name='reviewed_tickets') + + +class SerializerRelatedChoicesTest(TestCase): + +    def setUp(self): +        super(SerializerRelatedChoicesTest, self).setUp() + +        class RelatedChoicesSerializer(serializers.ModelSerializer): +            class Meta: +                model = Ticket +                fields = ('assigned', 'reviewer') + +        self.related_fields_serializer = RelatedChoicesSerializer + +    def test_empty_queryset_required(self): +        serializer = self.related_fields_serializer() +        self.assertEqual(serializer.fields['assigned'].queryset.count(), 0) +        self.assertEqual( +            [x for x in serializer.fields['assigned'].widget.choices], +            [] +        ) + +    def test_empty_queryset_not_required(self): +        serializer = self.related_fields_serializer() +        self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0) +        self.assertEqual( +            [x for x in serializer.fields['reviewer'].widget.choices], +            [('', '---------')] +        ) + +    def test_with_some_persons_required(self): +        Person.objects.create(name="Lionel Messi") +        Person.objects.create(name="Xavi Hernandez") +        serializer = self.related_fields_serializer() +        self.assertEqual(serializer.fields['assigned'].queryset.count(), 2) +        self.assertEqual( +            [x for x in serializer.fields['assigned'].widget.choices], +            [(1, 'Person object - 1'), (2, 'Person object - 2')] +        ) + +    def test_with_some_persons_not_required(self): +        Person.objects.create(name="Lionel Messi") +        Person.objects.create(name="Xavi Hernandez") +        serializer = self.related_fields_serializer() +        self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2) +        self.assertEqual( +            [x for x in serializer.fields['reviewer'].widget.choices], +            [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')] +        ) + + +class DepthTest(TestCase): +    def test_implicit_nesting(self): + +        writer = Person.objects.create(name="django", age=1) +        post = BlogPost.objects.create(title="Test blog post", writer=writer) +        comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) + +        class BlogPostCommentSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlogPostComment +                depth = 2 + +        serializer = BlogPostCommentSerializer(instance=comment) +        expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', +                    'writer': {'id': 1, 'name': 'django', 'age': 1}}} + +        self.assertEqual(serializer.data, expected) + +    def test_explicit_nesting(self): +        writer = Person.objects.create(name="django", age=1) +        post = BlogPost.objects.create(title="Test blog post", writer=writer) +        comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) + +        class PersonSerializer(serializers.ModelSerializer): +            class Meta: +                model = Person + +        class BlogPostSerializer(serializers.ModelSerializer): +            writer = PersonSerializer() + +            class Meta: +                model = BlogPost + +        class BlogPostCommentSerializer(serializers.ModelSerializer): +            blog_post = BlogPostSerializer() + +            class Meta: +                model = BlogPostComment + +        serializer = BlogPostCommentSerializer(instance=comment) +        expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', +                    'writer': {'id': 1, 'name': 'django', 'age': 1}}} + +        self.assertEqual(serializer.data, expected) + + +class NestedSerializerContextTests(TestCase): + +    def test_nested_serializer_context(self): +        """ +        Regression for #497 + +        https://github.com/tomchristie/django-rest-framework/issues/497 +        """ +        class PhotoSerializer(serializers.ModelSerializer): +            class Meta: +                model = Photo +                fields = ("description", "callable") + +            callable = serializers.SerializerMethodField('_callable') + +            def _callable(self, instance): +                if not 'context_item' in self.context: +                    raise RuntimeError("context isn't getting passed into 2nd level nested serializer") +                return "success" + +        class AlbumSerializer(serializers.ModelSerializer): +            class Meta: +                model = Album +                fields = ("photo_set", "callable") + +            photo_set = PhotoSerializer(source="photo_set") +            callable = serializers.SerializerMethodField("_callable") + +            def _callable(self, instance): +                if not 'context_item' in self.context: +                    raise RuntimeError("context isn't getting passed into 1st level nested serializer") +                return "success" + +        class AlbumCollection(object): +            albums = None + +        class AlbumCollectionSerializer(serializers.Serializer): +            albums = AlbumSerializer(source="albums") + +        album1 = Album.objects.create(title="album 1") +        album2 = Album.objects.create(title="album 2") +        Photo.objects.create(description="Bigfoot", album=album1) +        Photo.objects.create(description="Unicorn", album=album1) +        Photo.objects.create(description="Yeti", album=album2) +        Photo.objects.create(description="Sasquatch", album=album2) +        album_collection = AlbumCollection() +        album_collection.albums = [album1, album2] + +        # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers +        AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data + + +class DeserializeListTestCase(TestCase): + +    def setUp(self): +        self.data = { +            'email': 'nobody@nowhere.com', +            'content': 'This is some test content', +            'created': datetime.datetime(2013, 3, 7), +        } + +    def test_no_errors(self): +        data = [self.data.copy() for x in range(0, 3)] +        serializer = CommentSerializer(data=data, many=True) +        self.assertTrue(serializer.is_valid()) +        self.assertTrue(isinstance(serializer.object, list)) +        self.assertTrue( +            all((isinstance(item, Comment) for item in serializer.object)) +        ) + +    def test_errors_return_as_list(self): +        invalid_item = self.data.copy() +        invalid_item['email'] = '' +        data = [self.data.copy(), invalid_item, self.data.copy()] + +        serializer = CommentSerializer(data=data, many=True) +        self.assertFalse(serializer.is_valid()) +        expected = [{}, {'email': ['This field is required.']}, {}] +        self.assertEqual(serializer.errors, expected) + + +# Test for issue 747 + +class LazyStringModel(object): +    def __init__(self, lazystring): +        self.lazystring = lazystring + + +class LazyStringSerializer(serializers.Serializer): +    lazystring = serializers.Field() + +    def restore_object(self, attrs, instance=None): +        if instance is not None: +            instance.lazystring = attrs.get('lazystring', instance.lazystring) +            return instance +        return LazyStringModel(**attrs) + + +class LazyStringsTestCase(TestCase): +    def setUp(self): +        self.model = LazyStringModel(lazystring=_('lazystring')) + +    def test_lazy_strings_are_translated(self): +        serializer = LazyStringSerializer(self.model) +        self.assertEqual(type(serializer.data['lazystring']), +                         type('lazystring')) + + +# Test for issue #467 + +class FieldLabelTest(TestCase): +    def setUp(self): +        self.serializer_class = BasicModelSerializer + +    def test_label_from_model(self): +        """ +        Validates that label and help_text are correctly copied from the model class. +        """ +        serializer = self.serializer_class() +        text_field = serializer.fields['text'] + +        self.assertEqual('Text comes here', text_field.label) +        self.assertEqual('Text description.', text_field.help_text) + +    def test_field_ctor(self): +        """ +        This is check that ctor supports both label and help_text. +        """ +        self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label) +        self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text) +        self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) + + +# Test for issue #961 + +class ManyFieldHelpTextTest(TestCase): +    def test_help_text_no_hold_down_control_msg(self): +        """ +        Validate that help_text doesn't contain the 'Hold down "Control" ...' +        message that Django appends to choice fields. +        """ +        rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text) +        self.assertEqual('Some help text.', rel_field.help_text) + + +class AttributeMappingOnAutogeneratedRelatedFields(TestCase): + +    def test_primary_key_related_field(self): +        serializer = ForeignKeySourceSerializer() +        self.assertEqual(serializer.fields['target'].help_text, 'Target') +        self.assertEqual(serializer.fields['target'].label, 'Target') + +    def test_hyperlinked_related_field(self): +        serializer = HyperlinkedForeignKeySourceSerializer() +        self.assertEqual(serializer.fields['target'].help_text, 'Target') +        self.assertEqual(serializer.fields['target'].label, 'Target') + + +@unittest.skipUnless(PIL is not None, 'PIL is not installed') +class AttributeMappingOnAutogeneratedFieldsTests(TestCase): + +    def setUp(self): + +        class AMOAFSerializer(serializers.ModelSerializer): +            class Meta: +                model = AMOAFModel + +        self.serializer_class = AMOAFSerializer +        self.fields_attributes = { +            'char_field': [ +                ('max_length', 1024), +            ], +            'comma_separated_integer_field': [ +                ('max_length', 1024), +            ], +            'decimal_field': [ +                ('max_digits', 64), +                ('decimal_places', 32), +            ], +            'email_field': [ +                ('max_length', 1024), +            ], +            'file_field': [ +                ('max_length', 1024), +            ], +            'image_field': [ +                ('max_length', 1024), +            ], +            'slug_field': [ +                ('max_length', 1024), +            ], +            'url_field': [ +                ('max_length', 1024), +            ], +        } + +    def field_test(self, field): +        serializer = self.serializer_class(data={}) +        self.assertEqual(serializer.is_valid(), True) + +        for attribute in self.fields_attributes[field]: +            self.assertEqual( +                getattr(serializer.fields[field], attribute[0]), +                attribute[1] +            ) + +    def test_char_field(self): +        self.field_test('char_field') + +    def test_comma_separated_integer_field(self): +        self.field_test('comma_separated_integer_field') + +    def test_decimal_field(self): +        self.field_test('decimal_field') + +    def test_email_field(self): +        self.field_test('email_field') + +    def test_file_field(self): +        self.field_test('file_field') + +    def test_image_field(self): +        self.field_test('image_field') + +    def test_slug_field(self): +        self.field_test('slug_field') + +    def test_url_field(self): +        self.field_test('url_field') + + +@unittest.skipUnless(PIL is not None, 'PIL is not installed') +class DefaultValuesOnAutogeneratedFieldsTests(TestCase): + +    def setUp(self): + +        class DVOAFSerializer(serializers.ModelSerializer): +            class Meta: +                model = DVOAFModel + +        self.serializer_class = DVOAFSerializer +        self.fields_attributes = { +            'positive_integer_field': [ +                ('min_value', 0), +            ], +            'positive_small_integer_field': [ +                ('min_value', 0), +            ], +            'email_field': [ +                ('max_length', 75), +            ], +            'file_field': [ +                ('max_length', 100), +            ], +            'image_field': [ +                ('max_length', 100), +            ], +            'slug_field': [ +                ('max_length', 50), +            ], +            'url_field': [ +                ('max_length', 200), +            ], +        } + +    def field_test(self, field): +        serializer = self.serializer_class(data={}) +        self.assertEqual(serializer.is_valid(), True) + +        for attribute in self.fields_attributes[field]: +            self.assertEqual( +                getattr(serializer.fields[field], attribute[0]), +                attribute[1] +            ) + +    def test_positive_integer_field(self): +        self.field_test('positive_integer_field') + +    def test_positive_small_integer_field(self): +        self.field_test('positive_small_integer_field') + +    def test_email_field(self): +        self.field_test('email_field') + +    def test_file_field(self): +        self.field_test('file_field') + +    def test_image_field(self): +        self.field_test('image_field') + +    def test_slug_field(self): +        self.field_test('slug_field') + +    def test_url_field(self): +        self.field_test('url_field') + + +class MetadataSerializer(serializers.Serializer): +    field1 = serializers.CharField(3, required=True) +    field2 = serializers.CharField(10, required=False) + + +class MetadataSerializerTestCase(TestCase): +    def setUp(self): +        self.serializer = MetadataSerializer() + +    def test_serializer_metadata(self): +        metadata = self.serializer.metadata() +        expected = { +            'field1': { +                'required': True, +                'max_length': 3, +                'type': 'string', +                'read_only': False +            }, +            'field2': { +                'required': False, +                'max_length': 10, +                'type': 'string', +                'read_only': False +            } +        } +        self.assertEqual(expected, metadata) + + +### Regression test for #840 + +class SimpleModel(models.Model): +    text = models.CharField(max_length=100) + + +class SimpleModelSerializer(serializers.ModelSerializer): +    text = serializers.CharField() +    other = serializers.CharField() + +    class Meta: +        model = SimpleModel + +    def validate_other(self, attrs, source): +        del attrs['other'] +        return attrs + + +class FieldValidationRemovingAttr(TestCase): +    def test_removing_non_model_field_in_validation(self): +        """ +        Removing an attr during field valiation should ensure that it is not +        passed through when restoring the object. + +        This allows additional non-model fields to be supported. + +        Regression test for #840. +        """ +        serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'}) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.object.text, 'foo') + + +### Regression test for #878 + +class SimpleTargetModel(models.Model): +    text = models.CharField(max_length=100) + + +class SimplePKSourceModelSerializer(serializers.Serializer): +    targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True) +    text = serializers.CharField() + + +class SimpleSlugSourceModelSerializer(serializers.Serializer): +    targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk') +    text = serializers.CharField() + + +class SerializerSupportsManyRelationships(TestCase): +    def setUp(self): +        SimpleTargetModel.objects.create(text='foo') +        SimpleTargetModel.objects.create(text='bar') + +    def test_serializer_supports_pk_many_relationships(self): +        """ +        Regression test for #878. + +        Note that pk behavior has a different code path to usual cases, +        for performance reasons. +        """ +        serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) + +    def test_serializer_supports_slug_many_relationships(self): +        """ +        Regression test for #878. +        """ +        serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) + + +class TransformMethodsSerializer(serializers.Serializer): +    a = serializers.CharField() +    b_renamed = serializers.CharField(source='b') + +    def transform_a(self, obj, value): +        return value.lower() + +    def transform_b_renamed(self, obj, value): +        if value is not None: +            return 'and ' + value + + +class TestSerializerTransformMethods(TestCase): +    def setUp(self): +        self.s = TransformMethodsSerializer() + +    def test_transform_methods(self): +        self.assertEqual( +            self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}), +            { +                'a': 'green eggs', +                'b_renamed': 'and HAM', +            } +        ) + +    def test_missing_fields(self): +        self.assertEqual( +            self.s.to_native({'a': 'GREEN EGGS'}), +            { +                'a': 'green eggs', +                'b_renamed': None, +            } +        ) + + +class DefaultTrueBooleanModel(models.Model): +    cat = models.BooleanField(default=True) +    dog = models.BooleanField(default=False) + + +class SerializerDefaultTrueBoolean(TestCase): + +    def setUp(self): +        super(SerializerDefaultTrueBoolean, self).setUp() + +        class DefaultTrueBooleanSerializer(serializers.ModelSerializer): +            class Meta: +                model = DefaultTrueBooleanModel +                fields = ('cat', 'dog') + +        self.default_true_boolean_serializer = DefaultTrueBooleanSerializer + +    def test_enabled_as_false(self): +        serializer = self.default_true_boolean_serializer(data={'cat': False, +                                                                'dog': False}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data['cat'], False) +        self.assertEqual(serializer.data['dog'], False) + +    def test_enabled_as_true(self): +        serializer = self.default_true_boolean_serializer(data={'cat': True, +                                                                'dog': True}) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data['cat'], True) +        self.assertEqual(serializer.data['dog'], True) + +    def test_enabled_partial(self): +        serializer = self.default_true_boolean_serializer(data={'cat': False}, +                                                          partial=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data['cat'], False) +        self.assertEqual(serializer.data['dog'], False) + + +class BoolenFieldTypeTest(TestCase): +    ''' +    Ensure the various Boolean based model fields are rendered as the proper +    field type + +    ''' + +    def setUp(self): +        ''' +        Setup an ActionItemSerializer for BooleanTesting +        ''' +        data = { +            'title': 'b' * 201, +        } +        self.serializer = ActionItemSerializer(data=data) + +    def test_booleanfield_type(self): +        ''' +        Test that BooleanField is infered from models.BooleanField +        ''' +        bfield = self.serializer.get_fields()['done'] +        self.assertEqual(type(bfield), fields.BooleanField) + +    def test_nullbooleanfield_type(self): +        ''' +        Test that BooleanField is infered from models.NullBooleanField + +        https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8 +        ''' +        bfield = self.serializer.get_fields()['started'] +        self.assertEqual(type(bfield), fields.BooleanField) diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py new file mode 100644 index 00000000..8b0ded1a --- /dev/null +++ b/tests/test_serializer_bulk_update.py @@ -0,0 +1,278 @@ +""" +Tests to cover bulk create and update using serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class BulkCreateSerializerTests(TestCase): +    """ +    Creating multiple instances using serializers. +    """ + +    def setUp(self): +        class BookSerializer(serializers.Serializer): +            id = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            author = serializers.CharField(max_length=100) + +        self.BookSerializer = BookSerializer + +    def test_bulk_create_success(self): +        """ +        Correct bulk update serialization should return the input data. +        """ + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 2, +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] + +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, data) + +    def test_bulk_create_errors(self): +        """ +        Correct bulk update serialization should return the input data. +        """ + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 'foo', +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] +        expected_errors = [ +            {}, +            {}, +            {'id': ['Enter a whole number.']} +        ] + +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_list_datatype(self): +        """ +        Data containing list of incorrect data type should return errors. +        """ +        data = ['foo', 'bar', 'baz'] +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = [ +                {'non_field_errors': ['Invalid data']}, +                {'non_field_errors': ['Invalid data']}, +                {'non_field_errors': ['Invalid data']} +        ] + +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_single_datatype(self): +        """ +        Data containing a single incorrect data type should return errors. +        """ +        data = 123 +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = {'non_field_errors': ['Expected a list of items.']} + +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_single_object(self): +        """ +        Data containing only a single object, instead of a list of objects +        should return errors. +        """ +        data = { +            'id': 0, +            'title': 'The electric kool-aid acid test', +            'author': 'Tom Wolfe' +        } +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = {'non_field_errors': ['Expected a list of items.']} + +        self.assertEqual(serializer.errors, expected_errors) + + +class BulkUpdateSerializerTests(TestCase): +    """ +    Updating multiple instances using serializers. +    """ + +    def setUp(self): +        class Book(object): +            """ +            A data type that can be persisted to a mock storage backend +            with `.save()` and `.delete()`. +            """ +            object_map = {} + +            def __init__(self, id, title, author): +                self.id = id +                self.title = title +                self.author = author + +            def save(self): +                Book.object_map[self.id] = self + +            def delete(self): +                del Book.object_map[self.id] + +        class BookSerializer(serializers.Serializer): +            id = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            author = serializers.CharField(max_length=100) + +            def restore_object(self, attrs, instance=None): +                if instance: +                    instance.id = attrs['id'] +                    instance.title = attrs['title'] +                    instance.author = attrs['author'] +                    return instance +                return Book(**attrs) + +        self.Book = Book +        self.BookSerializer = BookSerializer + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 2, +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] + +        for item in data: +            book = Book(item['id'], item['title'], item['author']) +            book.save() + +    def books(self): +        """ +        Return all the objects in the mock storage backend. +        """ +        return self.Book.object_map.values() + +    def test_bulk_update_success(self): +        """ +        Correct bulk update serialization should return the input data. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 2, +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) +        serializer.save() +        new_data = self.BookSerializer(self.books(), many=True).data + +        self.assertEqual(data, new_data) + +    def test_bulk_update_and_create(self): +        """ +        Bulk update serialization may also include created items. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 3, +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) +        serializer.save() +        new_data = self.BookSerializer(self.books(), many=True).data +        self.assertEqual(data, new_data) + +    def test_bulk_update_invalid_create(self): +        """ +        Bulk update serialization without allow_add_remove may not create items. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 3, +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        expected_errors = [ +            {}, +            {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + +    def test_bulk_update_error(self): +        """ +        Incorrect bulk update serialization should return error data. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 'foo', +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        expected_errors = [ +            {}, +            {'id': ['Enter a whole number.']} +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) diff --git a/tests/test_serializer_empty.py b/tests/test_serializer_empty.py new file mode 100644 index 00000000..30cff361 --- /dev/null +++ b/tests/test_serializer_empty.py @@ -0,0 +1,15 @@ +from django.test import TestCase +from rest_framework import serializers + + +class EmptySerializerTestCase(TestCase): +    def test_empty_serializer(self): +        class FooBarSerializer(serializers.Serializer): +            foo = serializers.IntegerField() +            bar = serializers.SerializerMethodField('get_bar') + +            def get_bar(self, obj): +                return 'bar' + +        serializer = FooBarSerializer() +        self.assertEquals(serializer.data, {'foo': 0}) diff --git a/tests/test_serializer_import.py b/tests/test_serializer_import.py new file mode 100644 index 00000000..3b8ff4b3 --- /dev/null +++ b/tests/test_serializer_import.py @@ -0,0 +1,19 @@ +from django.test import TestCase + +from rest_framework import serializers +from tests.accounts.serializers import AccountSerializer + + +class ImportingModelSerializerTests(TestCase): +    """ +    In some situations like, GH #1225, it is possible, especially in +    testing, to import a serializer who's related models have not yet +    been resolved by Django. `AccountSerializer` is an example of such +    a serializer (imported at the top of this file). +    """ +    def test_import_model_serializer(self): +        """ +        The serializer at the top of this file should have been +        imported successfully, and we should be able to instantiate it. +        """ +        self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer) diff --git a/tests/test_serializer_nested.py b/tests/test_serializer_nested.py new file mode 100644 index 00000000..6d69ffbd --- /dev/null +++ b/tests/test_serializer_nested.py @@ -0,0 +1,347 @@ +""" +Tests to cover nested serializers. + +Doesn't cover model serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers +from . import models + + +class WritableNestedSerializerBasicTests(TestCase): +    """ +    Tests for deserializing nested entities. +    Basic tests that use serializers that simply restore to dicts. +    """ + +    def setUp(self): +        class TrackSerializer(serializers.Serializer): +            order = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            duration = serializers.IntegerField() + +        class AlbumSerializer(serializers.Serializer): +            album_name = serializers.CharField(max_length=100) +            artist = serializers.CharField(max_length=100) +            tracks = TrackSerializer(many=True) + +        self.AlbumSerializer = AlbumSerializer + +    def test_nested_validation_success(self): +        """ +        Correct nested serialization should return the input data. +        """ + +        data = { +            'album_name': 'Discovery', +            'artist': 'Daft Punk', +            'tracks': [ +                {'order': 1, 'title': 'One More Time', 'duration': 235}, +                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                {'order': 3, 'title': 'Digital Love', 'duration': 239} +            ] +        } + +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, data) + +    def test_nested_validation_error(self): +        """ +        Incorrect nested serialization should return appropriate error data. +        """ + +        data = { +            'album_name': 'Discovery', +            'artist': 'Daft Punk', +            'tracks': [ +                {'order': 1, 'title': 'One More Time', 'duration': 235}, +                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} +            ] +        } +        expected_errors = { +            'tracks': [ +                {}, +                {}, +                {'duration': ['Enter a whole number.']} +            ] +        } + +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + +    def test_many_nested_validation_error(self): +        """ +        Incorrect nested serialization should return appropriate error data +        when multiple entities are being deserialized. +        """ + +        data = [ +            { +                'album_name': 'Russian Red', +                'artist': 'I Love Your Glasses', +                'tracks': [ +                    {'order': 1, 'title': 'Cigarettes', 'duration': 121}, +                    {'order': 2, 'title': 'No Past Land', 'duration': 198}, +                    {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} +                ] +            }, +            { +                'album_name': 'Discovery', +                'artist': 'Daft Punk', +                'tracks': [ +                    {'order': 1, 'title': 'One More Time', 'duration': 235}, +                    {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                    {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} +                ] +            } +        ] +        expected_errors = [ +            {}, +            { +                'tracks': [ +                    {}, +                    {}, +                    {'duration': ['Enter a whole number.']} +                ] +            } +        ] + +        serializer = self.AlbumSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + + +class WritableNestedSerializerObjectTests(TestCase): +    """ +    Tests for deserializing nested entities. +    These tests use serializers that restore to concrete objects. +    """ + +    def setUp(self): +        # Couple of concrete objects that we're going to deserialize into +        class Track(object): +            def __init__(self, order, title, duration): +                self.order, self.title, self.duration = order, title, duration + +            def __eq__(self, other): +                return ( +                    self.order == other.order and +                    self.title == other.title and +                    self.duration == other.duration +                ) + +        class Album(object): +            def __init__(self, album_name, artist, tracks): +                self.album_name, self.artist, self.tracks = album_name, artist, tracks + +            def __eq__(self, other): +                return ( +                    self.album_name == other.album_name and +                    self.artist == other.artist and +                    self.tracks == other.tracks +                ) + +        # And their corresponding serializers +        class TrackSerializer(serializers.Serializer): +            order = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            duration = serializers.IntegerField() + +            def restore_object(self, attrs, instance=None): +                return Track(attrs['order'], attrs['title'], attrs['duration']) + +        class AlbumSerializer(serializers.Serializer): +            album_name = serializers.CharField(max_length=100) +            artist = serializers.CharField(max_length=100) +            tracks = TrackSerializer(many=True) + +            def restore_object(self, attrs, instance=None): +                return Album(attrs['album_name'], attrs['artist'], attrs['tracks']) + +        self.Album, self.Track = Album, Track +        self.AlbumSerializer = AlbumSerializer + +    def test_nested_validation_success(self): +        """ +        Correct nested serialization should return a restored object +        that corresponds to the input data. +        """ + +        data = { +            'album_name': 'Discovery', +            'artist': 'Daft Punk', +            'tracks': [ +                {'order': 1, 'title': 'One More Time', 'duration': 235}, +                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                {'order': 3, 'title': 'Digital Love', 'duration': 239} +            ] +        } +        expected_object = self.Album( +            album_name='Discovery', +            artist='Daft Punk', +            tracks=[ +                self.Track(order=1, title='One More Time', duration=235), +                self.Track(order=2, title='Aerodynamic', duration=184), +                self.Track(order=3, title='Digital Love', duration=239), +            ] +        ) + +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected_object) + +    def test_many_nested_validation_success(self): +        """ +        Correct nested serialization should return multiple restored objects +        that corresponds to the input data when multiple objects are +        being deserialized. +        """ + +        data = [ +            { +                'album_name': 'Russian Red', +                'artist': 'I Love Your Glasses', +                'tracks': [ +                    {'order': 1, 'title': 'Cigarettes', 'duration': 121}, +                    {'order': 2, 'title': 'No Past Land', 'duration': 198}, +                    {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} +                ] +            }, +            { +                'album_name': 'Discovery', +                'artist': 'Daft Punk', +                'tracks': [ +                    {'order': 1, 'title': 'One More Time', 'duration': 235}, +                    {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                    {'order': 3, 'title': 'Digital Love', 'duration': 239} +                ] +            } +        ] +        expected_object = [ +            self.Album( +                album_name='Russian Red', +                artist='I Love Your Glasses', +                tracks=[ +                    self.Track(order=1, title='Cigarettes', duration=121), +                    self.Track(order=2, title='No Past Land', duration=198), +                    self.Track(order=3, title='They Don\'t Believe', duration=191), +                ] +            ), +            self.Album( +                album_name='Discovery', +                artist='Daft Punk', +                tracks=[ +                    self.Track(order=1, title='One More Time', duration=235), +                    self.Track(order=2, title='Aerodynamic', duration=184), +                    self.Track(order=3, title='Digital Love', duration=239), +                ] +            ) +        ] + +        serializer = self.AlbumSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected_object) + + +class ForeignKeyNestedSerializerUpdateTests(TestCase): +    def setUp(self): +        class Artist(object): +            def __init__(self, name): +                self.name = name + +            def __eq__(self, other): +                return self.name == other.name + +        class Album(object): +            def __init__(self, name, artist): +                self.name, self.artist = name, artist + +            def __eq__(self, other): +                return self.name == other.name and self.artist == other.artist + +        class ArtistSerializer(serializers.Serializer): +            name = serializers.CharField() + +            def restore_object(self, attrs, instance=None): +                if instance: +                    instance.name = attrs['name'] +                else: +                    instance = Artist(attrs['name']) +                return instance + +        class AlbumSerializer(serializers.Serializer): +            name = serializers.CharField() +            by = ArtistSerializer(source='artist') + +            def restore_object(self, attrs, instance=None): +                if instance: +                    instance.name = attrs['name'] +                    instance.artist = attrs['artist'] +                else: +                    instance = Album(attrs['name'], attrs['artist']) +                return instance + +        self.Artist = Artist +        self.Album = Album +        self.AlbumSerializer = AlbumSerializer + +    def test_create_via_foreign_key_with_source(self): +        """ +        Check that we can both *create* and *update* into objects across +        ForeignKeys that have a `source` specified. +        Regression test for #1170 +        """ +        data = { +            'name': 'Discovery', +            'by': {'name': 'Daft Punk'}, +        } + +        expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery') + +        # create +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected) + +        # update +        original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters') +        serializer = self.AlbumSerializer(instance=original, data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected) + + +class NestedModelSerializerUpdateTests(TestCase): +    def test_second_nested_level(self): +        john = models.Person.objects.create(name="john") + +        post = john.blogpost_set.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostCommentSerializer(serializers.ModelSerializer): +            class Meta: +                model = models.BlogPostComment + +        class BlogPostSerializer(serializers.ModelSerializer): +            comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set') +            class Meta: +                model = models.BlogPost +                fields = ('id', 'title', 'comments') + +        class PersonSerializer(serializers.ModelSerializer): +            posts = BlogPostSerializer(many=True, source='blogpost_set') +            class Meta: +                model = models.Person +                fields = ('id', 'name', 'age', 'posts') + +        serialize = PersonSerializer(instance=john) +        deserialize = PersonSerializer(data=serialize.data, instance=john) +        self.assertTrue(deserialize.is_valid()) + +        result = deserialize.object +        result.save() +        self.assertEqual(result.id, john.id) diff --git a/tests/test_serializers.py b/tests/test_serializers.py new file mode 100644 index 00000000..67547783 --- /dev/null +++ b/tests/test_serializers.py @@ -0,0 +1,28 @@ +from django.db import models +from django.test import TestCase + +from rest_framework.serializers import _resolve_model +from tests.models import BasicModel + + +class ResolveModelTests(TestCase): +    """ +    `_resolve_model` should return a Django model class given the +    provided argument is a Django model class itself, or a properly +    formatted string representation of one. +    """ +    def test_resolve_django_model(self): +        resolved_model = _resolve_model(BasicModel) +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_string_representation(self): +        resolved_model = _resolve_model('tests.BasicModel') +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_non_django_model(self): +        with self.assertRaises(ValueError): +            _resolve_model(TestCase) + +    def test_resolve_improper_string_representation(self): +        with self.assertRaises(ValueError): +            _resolve_model('BasicModel') diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 00000000..e29fc34a --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,22 @@ +"""Tests for the settings module""" +from __future__ import unicode_literals +from django.test import TestCase + +from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS + + +class TestSettings(TestCase): +    """Tests relating to the api settings""" + +    def test_non_import_errors(self): +        """Make sure other errors aren't suppressed.""" +        settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) +        with self.assertRaises(ValueError): +            settings.DEFAULT_MODEL_SERIALIZER_CLASS + +    def test_import_error_message_maintained(self): +        """Make sure real import errors are captured and raised sensibly.""" +        settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) +        with self.assertRaises(ImportError) as cm: +            settings.DEFAULT_MODEL_SERIALIZER_CLASS +        self.assertTrue('ImportError' in str(cm.exception)) diff --git a/tests/test_status.py b/tests/test_status.py new file mode 100644 index 00000000..7b1bdae3 --- /dev/null +++ b/tests/test_status.py @@ -0,0 +1,33 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.status import ( +    is_informational, is_success, is_redirect, is_client_error, is_server_error +) + + +class TestStatus(TestCase): +    def test_status_categories(self): +        self.assertFalse(is_informational(99)) +        self.assertTrue(is_informational(100)) +        self.assertTrue(is_informational(199)) +        self.assertFalse(is_informational(200)) + +        self.assertFalse(is_success(199)) +        self.assertTrue(is_success(200)) +        self.assertTrue(is_success(299)) +        self.assertFalse(is_success(300)) + +        self.assertFalse(is_redirect(299)) +        self.assertTrue(is_redirect(300)) +        self.assertTrue(is_redirect(399)) +        self.assertFalse(is_redirect(400)) + +        self.assertFalse(is_client_error(399)) +        self.assertTrue(is_client_error(400)) +        self.assertTrue(is_client_error(499)) +        self.assertFalse(is_client_error(500)) + +        self.assertFalse(is_server_error(499)) +        self.assertTrue(is_server_error(500)) +        self.assertTrue(is_server_error(599)) +        self.assertFalse(is_server_error(600))
\ No newline at end of file diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py new file mode 100644 index 00000000..d4da0c23 --- /dev/null +++ b/tests/test_templatetags.py @@ -0,0 +1,51 @@ +# encoding: utf-8 +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.test import APIRequestFactory +from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links + +factory = APIRequestFactory() + + +class TemplateTagTests(TestCase): + +    def test_add_query_param_with_non_latin_charactor(self): +        # Ensure we don't double-escape non-latin characters +        # that are present in the querystring. +        # See #1314. +        request = factory.get("/", {'q': '查询'}) +        json_url = add_query_param(request, "format", "json") +        self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url) +        self.assertIn("format=json", json_url) + + +class Issue1386Tests(TestCase): +    """ +    Covers #1386 +    """ + +    def test_issue_1386(self): +        """ +        Test function urlize_quoted_links with different args +        """ +        correct_urls = [ +            "asdf.com", +            "asdf.net", +            "www.as_df.org", +            "as.d8f.ghj8.gov", +        ] +        for i in correct_urls: +            res = urlize_quoted_links(i) +            self.assertNotEqual(res, i) +            self.assertIn(i, res) + +        incorrect_urls = [ +            "mailto://asdf@fdf.com", +            "asdf.netnet", +        ] +        for i in incorrect_urls: +            res = urlize_quoted_links(i) +            self.assertEqual(i, res) + +        # example from issue #1386, this shouldn't raise an exception +        _ = urlize_quoted_links("asdf:[/p]zxcv.com") diff --git a/tests/test_testing.py b/tests/test_testing.py new file mode 100644 index 00000000..bd3e1329 --- /dev/null +++ b/tests/test_testing.py @@ -0,0 +1,164 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from io import BytesIO + +from django.contrib.auth.models import User +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.test import APIClient, APIRequestFactory, force_authenticate + + +@api_view(['GET', 'POST']) +def view(request): +    return Response({ +        'auth': request.META.get('HTTP_AUTHORIZATION', b''), +        'user': request.user.username +    }) + + +@api_view(['GET', 'POST']) +def session_view(request): +    active_session = request.session.get('active_session', False) +    request.session['active_session'] = True +    return Response({ +        'active_session': active_session +    }) + + +urlpatterns = patterns('', +    url(r'^view/$', view), +    url(r'^session-view/$', session_view), +) + + +class TestAPITestClient(TestCase): +    urls = 'tests.test_testing' + +    def setUp(self): +        self.client = APIClient() + +    def test_credentials(self): +        """ +        Setting `.credentials()` adds the required headers to each request. +        """ +        self.client.credentials(HTTP_AUTHORIZATION='example') +        for _ in range(0, 3): +            response = self.client.get('/view/') +            self.assertEqual(response.data['auth'], 'example') + +    def test_force_authenticate(self): +        """ +        Setting `.force_authenticate()` forcibly authenticates each request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        self.client.force_authenticate(user) +        response = self.client.get('/view/') +        self.assertEqual(response.data['user'], 'example') + +    def test_force_authenticate_with_sessions(self): +        """ +        Setting `.force_authenticate()` forcibly authenticates each request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        self.client.force_authenticate(user) + +        # First request does not yet have an active session +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], False) + +        # Subsequant requests have an active session +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], True) + +        # Force authenticating as `None` should also logout the user session. +        self.client.force_authenticate(None) +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], False) + +    def test_csrf_exempt_by_default(self): +        """ +        By default, the test client is CSRF exempt. +        """ +        User.objects.create_user('example', 'example@example.com', 'password') +        self.client.login(username='example', password='password') +        response = self.client.post('/view/') +        self.assertEqual(response.status_code, 200) + +    def test_explicitly_enforce_csrf_checks(self): +        """ +        The test client can enforce CSRF checks. +        """ +        client = APIClient(enforce_csrf_checks=True) +        User.objects.create_user('example', 'example@example.com', 'password') +        client.login(username='example', password='password') +        response = client.post('/view/') +        expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} +        self.assertEqual(response.status_code, 403) +        self.assertEqual(response.data, expected) + + +class TestAPIRequestFactory(TestCase): +    def test_csrf_exempt_by_default(self): +        """ +        By default, the test client is CSRF exempt. +        """ +        user = User.objects.create_user('example', 'example@example.com', 'password') +        factory = APIRequestFactory() +        request = factory.post('/view/') +        request.user = user +        response = view(request) +        self.assertEqual(response.status_code, 200) + +    def test_explicitly_enforce_csrf_checks(self): +        """ +        The test client can enforce CSRF checks. +        """ +        user = User.objects.create_user('example', 'example@example.com', 'password') +        factory = APIRequestFactory(enforce_csrf_checks=True) +        request = factory.post('/view/') +        request.user = user +        response = view(request) +        expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} +        self.assertEqual(response.status_code, 403) +        self.assertEqual(response.data, expected) + +    def test_invalid_format(self): +        """ +        Attempting to use a format that is not configured will raise an +        assertion error. +        """ +        factory = APIRequestFactory() +        self.assertRaises(AssertionError, factory.post, +            path='/view/', data={'example': 1}, format='xml' +        ) + +    def test_force_authenticate(self): +        """ +        Setting `force_authenticate()` forcibly authenticates the request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        factory = APIRequestFactory() +        request = factory.get('/view') +        force_authenticate(request, user=user) +        response = view(request) +        self.assertEqual(response.data['user'], 'example') + +    def test_upload_file(self): +        # This is a 1x1 black png +        simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82') +        simple_png.name = 'test.png' +        factory = APIRequestFactory() +        factory.post('/', data={'image': simple_png}) + +    def test_request_factory_url_arguments(self): +        """ +        This is a non regression test against #1461 +        """ +        factory = APIRequestFactory() +        request = factory.get('/view/?demo=test') +        self.assertEqual(dict(request.GET), {'demo': ['test']}) +        request = factory.get('/view/', {'demo': 'test'}) +        self.assertEqual(dict(request.GET), {'demo': ['test']}) diff --git a/tests/test_throttling.py b/tests/test_throttling.py new file mode 100644 index 00000000..41bff692 --- /dev/null +++ b/tests/test_throttling.py @@ -0,0 +1,277 @@ +""" +Tests for the throttling implementations in the permissions module. +""" +from __future__ import unicode_literals +from django.test import TestCase +from django.contrib.auth.models import User +from django.core.cache import cache +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView +from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle +from rest_framework.response import Response + + +class User3SecRateThrottle(UserRateThrottle): +    rate = '3/sec' +    scope = 'seconds' + + +class User3MinRateThrottle(UserRateThrottle): +    rate = '3/min' +    scope = 'minutes' + + +class NonTimeThrottle(BaseThrottle): +    def allow_request(self, request, view): +        if not hasattr(self.__class__, 'called'): +            self.__class__.called = True +            return True +        return False  + + +class MockView(APIView): +    throttle_classes = (User3SecRateThrottle,) + +    def get(self, request): +        return Response('foo') + + +class MockView_MinuteThrottling(APIView): +    throttle_classes = (User3MinRateThrottle,) + +    def get(self, request): +        return Response('foo') + + +class MockView_NonTimeThrottling(APIView): +    throttle_classes = (NonTimeThrottle,) + +    def get(self, request): +        return Response('foo') + + +class ThrottlingTests(TestCase): +    def setUp(self): +        """ +        Reset the cache so that no throttles will be active +        """ +        cache.clear() +        self.factory = APIRequestFactory() + +    def test_requests_are_throttled(self): +        """ +        Ensure request rate is limited +        """ +        request = self.factory.get('/') +        for dummy in range(4): +            response = MockView.as_view()(request) +        self.assertEqual(429, response.status_code) + +    def set_throttle_timer(self, view, value): +        """ +        Explicitly set the timer, overriding time.time() +        """ +        view.throttle_classes[0].timer = lambda self: value + +    def test_request_throttling_expires(self): +        """ +        Ensure request rate is limited for a limited duration only +        """ +        self.set_throttle_timer(MockView, 0) + +        request = self.factory.get('/') +        for dummy in range(4): +            response = MockView.as_view()(request) +        self.assertEqual(429, response.status_code) + +        # Advance the timer by one second +        self.set_throttle_timer(MockView, 1) + +        response = MockView.as_view()(request) +        self.assertEqual(200, response.status_code) + +    def ensure_is_throttled(self, view, expect): +        request = self.factory.get('/') +        request.user = User.objects.create(username='a') +        for dummy in range(3): +            view.as_view()(request) +        request.user = User.objects.create(username='b') +        response = view.as_view()(request) +        self.assertEqual(expect, response.status_code) + +    def test_request_throttling_is_per_user(self): +        """ +        Ensure request rate is only limited per user, not globally for +        PerUserThrottles +        """ +        self.ensure_is_throttled(MockView, 200) + +    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): +        """ +        Ensure the response returns an X-Throttle field with status and next attributes +        set properly. +        """ +        request = self.factory.get('/') +        for timer, expect in expected_headers: +            self.set_throttle_timer(view, timer) +            response = view.as_view()(request) +            if expect is not None: +                self.assertEqual(response['X-Throttle-Wait-Seconds'], expect) +            else: +                self.assertFalse('X-Throttle-Wait-Seconds' in response) + +    def test_seconds_fields(self): +        """ +        Ensure for second based throttles. +        """ +        self.ensure_response_header_contains_proper_throttle_field(MockView, +         ((0, None), +          (0, None), +          (0, None), +          (0, '1') +         )) + +    def test_minutes_fields(self): +        """ +        Ensure for minute based throttles. +        """ +        self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, +         ((0, None), +          (0, None), +          (0, None), +          (0, '60') +         )) + +    def test_next_rate_remains_constant_if_followed(self): +        """ +        If a client follows the recommended next request rate, +        the throttling rate should stay constant. +        """ +        self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, +         ((0, None), +          (20, None), +          (40, None), +          (60, None), +          (80, None) +         )) + +    def test_non_time_throttle(self): +        """ +        Ensure for second based throttles. +        """ +        request = self.factory.get('/') + +        self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) + +        response = MockView_NonTimeThrottling.as_view()(request) +        self.assertFalse('X-Throttle-Wait-Seconds' in response) + +        self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) + +        response = MockView_NonTimeThrottling.as_view()(request) +        self.assertFalse('X-Throttle-Wait-Seconds' in response)  + + +class ScopedRateThrottleTests(TestCase): +    """ +    Tests for ScopedRateThrottle. +    """ + +    def setUp(self): +        class XYScopedRateThrottle(ScopedRateThrottle): +            TIMER_SECONDS = 0 +            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'} +            timer = lambda self: self.TIMER_SECONDS + +        class XView(APIView): +            throttle_classes = (XYScopedRateThrottle,) +            throttle_scope = 'x' + +            def get(self, request): +                return Response('x') + +        class YView(APIView): +            throttle_classes = (XYScopedRateThrottle,) +            throttle_scope = 'y' + +            def get(self, request): +                return Response('y') + +        class UnscopedView(APIView): +            throttle_classes = (XYScopedRateThrottle,) + +            def get(self, request): +                return Response('y') + +        self.throttle_class = XYScopedRateThrottle +        self.factory = APIRequestFactory() +        self.x_view = XView.as_view() +        self.y_view = YView.as_view() +        self.unscoped_view = UnscopedView.as_view() + +    def increment_timer(self, seconds=1): +        self.throttle_class.TIMER_SECONDS += seconds + +    def test_scoped_rate_throttle(self): +        request = self.factory.get('/') + +        # Should be able to hit x view 3 times per minute. +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(429, response.status_code) + +        # Should be able to hit y view 1 time per minute. +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(429, response.status_code) + +        # Ensure throttles properly reset by advancing the rest of the minute +        self.increment_timer(55) + +        # Should still be able to hit x view 3 times per minute. +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.x_view(request) +        self.assertEqual(429, response.status_code) + +        # Should still be able to hit y view 1 time per minute. +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(200, response.status_code) + +        self.increment_timer() +        response = self.y_view(request) +        self.assertEqual(429, response.status_code) + +    def test_unscoped_view_not_throttled(self): +        request = self.factory.get('/') + +        for idx in range(10): +            self.increment_timer() +            response = self.unscoped_view(request) +            self.assertEqual(200, response.status_code) diff --git a/tests/test_urlizer.py b/tests/test_urlizer.py new file mode 100644 index 00000000..3dc8e8fe --- /dev/null +++ b/tests/test_urlizer.py @@ -0,0 +1,38 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.templatetags.rest_framework import urlize_quoted_links +import sys + + +class URLizerTests(TestCase): +    """ +    Test if both JSON and YAML URLs are transformed into links well +    """ +    def _urlize_dict_check(self, data): +        """ +        For all items in dict test assert that the value is urlized key +        """ +        for original, urlized in data.items(): +            assert urlize_quoted_links(original, nofollow=False) == urlized + +    def test_json_with_url(self): +        """ +        Test if JSON URLs are transformed into links well +        """ +        data = {} +        data['"url": "http://api/users/1/", '] = \ +            '"url": "<a href="http://api/users/1/">http://api/users/1/</a>", ' +        data['"foo_set": [\n    "http://api/foos/1/"\n], '] = \ +            '"foo_set": [\n    "<a href="http://api/foos/1/">http://api/foos/1/</a>"\n], ' +        self._urlize_dict_check(data) + +    def test_yaml_with_url(self): +        """ +        Test if YAML URLs are transformed into links well +        """ +        data = {} +        data['''{users: 'http://api/users/'}'''] = \ +            '''{users: '<a href="http://api/users/">http://api/users/</a>'}''' +        data['''foo_set: ['http://api/foos/1/']'''] = \ +            '''foo_set: ['<a href="http://api/foos/1/">http://api/foos/1/</a>']''' +        self._urlize_dict_check(data) diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py new file mode 100644 index 00000000..8132ec4c --- /dev/null +++ b/tests/test_urlpatterns.py @@ -0,0 +1,76 @@ +from __future__ import unicode_literals +from collections import namedtuple +from django.core import urlresolvers +from django.test import TestCase +from rest_framework.test import APIRequestFactory +from rest_framework.compat import patterns, url, include +from rest_framework.urlpatterns import format_suffix_patterns + + +# A container class for test paths for the test case +URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) + + +def dummy_view(request, *args, **kwargs): +    pass + + +class FormatSuffixTests(TestCase): +    """ +    Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. +    """ +    def _resolve_urlpatterns(self, urlpatterns, test_paths): +        factory = APIRequestFactory() +        try: +            urlpatterns = format_suffix_patterns(urlpatterns) +        except Exception: +            self.fail("Failed to apply `format_suffix_patterns` on  the supplied urlpatterns") +        resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) +        for test_path in test_paths: +            request = factory.get(test_path.path) +            try: +                callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) +            except Exception: +                self.fail("Failed to resolve URL: %s" % request.path_info) +            self.assertEqual(callback_args, test_path.args) +            self.assertEqual(callback_kwargs, test_path.kwargs) + +    def test_format_suffix(self): +        urlpatterns = patterns( +            '', +            url(r'^test$', dummy_view), +        ) +        test_paths = [ +            URLTestPath('/test', (), {}), +            URLTestPath('/test.api', (), {'format': 'api'}), +            URLTestPath('/test.asdf', (), {'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) + +    def test_default_args(self): +        urlpatterns = patterns( +            '', +            url(r'^test$', dummy_view, {'foo': 'bar'}), +        ) +        test_paths = [ +            URLTestPath('/test', (), {'foo': 'bar', }), +            URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), +            URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) + +    def test_included_urls(self): +        nested_patterns = patterns( +            '', +            url(r'^path$', dummy_view) +        ) +        urlpatterns = patterns( +            '', +            url(r'^test/', include(nested_patterns), {'foo': 'bar'}), +        ) +        test_paths = [ +            URLTestPath('/test/path', (), {'foo': 'bar', }), +            URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), +            URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 00000000..e13e4078 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,148 @@ +from __future__ import unicode_literals +from django.core.validators import MaxValueValidator +from django.db import models +from django.test import TestCase +from rest_framework import generics, serializers, status +from rest_framework.test import APIRequestFactory + +factory = APIRequestFactory() + + +# Regression for #666 + +class ValidationModel(models.Model): +    blank_validated_field = models.CharField(max_length=255) + + +class ValidationModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ValidationModel +        fields = ('blank_validated_field',) +        read_only_fields = ('blank_validated_field',) + + +class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView): +    model = ValidationModel +    serializer_class = ValidationModelSerializer + + +class TestPreSaveValidationExclusions(TestCase): +    def test_pre_save_validation_exclusions(self): +        """ +        Somewhat weird test case to ensure that we don't perform model +        validation on read only fields. +        """ +        obj = ValidationModel.objects.create(blank_validated_field='') +        request = factory.put('/', {}, format='json') +        view = UpdateValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + + +# Regression for #653 + +class ShouldValidateModel(models.Model): +    should_validate_field = models.CharField(max_length=255) + + +class ShouldValidateModelSerializer(serializers.ModelSerializer): +    renamed = serializers.CharField(source='should_validate_field', required=False) + +    def validate_renamed(self, attrs, source): +        value = attrs[source] +        if len(value) < 3: +            raise serializers.ValidationError('Minimum 3 characters.') +        return attrs + +    class Meta: +        model = ShouldValidateModel +        fields = ('renamed',) + + +class TestPreSaveValidationExclusionsSerializer(TestCase): +    def test_renamed_fields_are_model_validated(self): +        """ +        Ensure fields with 'source' applied do get still get model validation. +        """ +        # We've set `required=False` on the serializer, but the model +        # does not have `blank=True`, so this serializer should not validate. +        serializer = ShouldValidateModelSerializer(data={'renamed': ''}) +        self.assertEqual(serializer.is_valid(), False) +        self.assertIn('renamed', serializer.errors) +        self.assertNotIn('should_validate_field', serializer.errors) + + +class TestCustomValidationMethods(TestCase): +    def test_custom_validation_method_is_executed(self): +        serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'}) +        self.assertFalse(serializer.is_valid()) +        self.assertIn('renamed', serializer.errors) + +    def test_custom_validation_method_passing(self): +        serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'}) +        self.assertTrue(serializer.is_valid()) + + +class ValidationSerializer(serializers.Serializer): +    foo = serializers.CharField() + +    def validate_foo(self, attrs, source): +        raise serializers.ValidationError("foo invalid") + +    def validate(self, attrs): +        raise serializers.ValidationError("serializer invalid") + + +class TestAvoidValidation(TestCase): +    """ +    If serializer was initialized with invalid data (None or non dict-like), it +    should avoid validation layer (validate_<field> and validate methods) +    """ +    def test_serializer_errors_has_only_invalid_data_error(self): +        serializer = ValidationSerializer(data='invalid data') +        self.assertFalse(serializer.is_valid()) +        self.assertDictEqual(serializer.errors, +                             {'non_field_errors': ['Invalid data']}) + + +# regression tests for issue: 1493 + +class ValidationMaxValueValidatorModel(models.Model): +    number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)]) + + +class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ValidationMaxValueValidatorModel + + +class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): +    model = ValidationMaxValueValidatorModel +    serializer_class = ValidationMaxValueValidatorModelSerializer + + +class TestMaxValueValidatorValidation(TestCase): + +    def test_max_value_validation_serializer_success(self): +        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99}) +        self.assertTrue(serializer.is_valid()) + +    def test_max_value_validation_serializer_fails(self): +        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101}) +        self.assertFalse(serializer.is_valid()) +        self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors) + +    def test_max_value_validation_success(self): +        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) +        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json') +        view = UpdateMaxValueValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_max_value_validation_fail(self): +        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) +        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') +        view = UpdateMaxValueValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/tests/test_views.py b/tests/test_views.py new file mode 100644 index 00000000..65c7e50e --- /dev/null +++ b/tests/test_views.py @@ -0,0 +1,142 @@ +from __future__ import unicode_literals + +import copy +from django.test import TestCase +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + + +class BasicView(APIView): +    def get(self, request, *args, **kwargs): +        return Response({'method': 'GET'}) + +    def post(self, request, *args, **kwargs): +        return Response({'method': 'POST', 'data': request.DATA}) + + +@api_view(['GET', 'POST', 'PUT', 'PATCH']) +def basic_view(request): +    if request.method == 'GET': +        return {'method': 'GET'} +    elif request.method == 'POST': +        return {'method': 'POST', 'data': request.DATA} +    elif request.method == 'PUT': +        return {'method': 'PUT', 'data': request.DATA} +    elif request.method == 'PATCH': +        return {'method': 'PATCH', 'data': request.DATA} + + +class ErrorView(APIView): +    def get(self, request, *args, **kwargs): +        raise Exception + + +@api_view(['GET']) +def error_view(request): +    raise Exception + + +def sanitise_json_error(error_dict): +    """ +    Exact contents of JSON error messages depend on the installed version +    of json. +    """ +    ret = copy.copy(error_dict) +    chop = len('JSON parse error - No JSON object could be decoded') +    ret['detail'] = ret['detail'][:chop] +    return ret + + +class ClassBasedViewIntegrationTests(TestCase): +    def setUp(self): +        self.view = BasicView.as_view() + +    def test_400_parse_error(self): +        request = factory.post('/', 'f00bar', content_type='application/json') +        response = self.view(request) +        expected = { +            'detail': 'JSON parse error - No JSON object could be decoded' +        } +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(sanitise_json_error(response.data), expected) + +    def test_400_parse_error_tunneled_content(self): +        content = 'f00bar' +        content_type = 'application/json' +        form_data = { +            api_settings.FORM_CONTENT_OVERRIDE: content, +            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type +        } +        request = factory.post('/', form_data) +        response = self.view(request) +        expected = { +            'detail': 'JSON parse error - No JSON object could be decoded' +        } +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(sanitise_json_error(response.data), expected) + + +class FunctionBasedViewIntegrationTests(TestCase): +    def setUp(self): +        self.view = basic_view + +    def test_400_parse_error(self): +        request = factory.post('/', 'f00bar', content_type='application/json') +        response = self.view(request) +        expected = { +            'detail': 'JSON parse error - No JSON object could be decoded' +        } +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(sanitise_json_error(response.data), expected) + +    def test_400_parse_error_tunneled_content(self): +        content = 'f00bar' +        content_type = 'application/json' +        form_data = { +            api_settings.FORM_CONTENT_OVERRIDE: content, +            api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type +        } +        request = factory.post('/', form_data) +        response = self.view(request) +        expected = { +            'detail': 'JSON parse error - No JSON object could be decoded' +        } +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(sanitise_json_error(response.data), expected) + + +class TestCustomExceptionHandler(TestCase): +    def setUp(self): +        self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + +        def exception_handler(exc): +            return Response('Error!', status=status.HTTP_400_BAD_REQUEST) + +        api_settings.EXCEPTION_HANDLER = exception_handler + +    def tearDown(self): +        api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + +    def test_class_based_view_exception_handler(self): +        view = ErrorView.as_view() + +        request = factory.get('/', content_type='application/json') +        response = view(request) +        expected = 'Error!' +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(response.data, expected) + +    def test_function_based_view_exception_handler(self): +        view = error_view + +        request = factory.get('/', content_type='application/json') +        response = view(request) +        expected = 'Error!' +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +        self.assertEqual(response.data, expected) diff --git a/tests/test_write_only_fields.py b/tests/test_write_only_fields.py new file mode 100644 index 00000000..aabb18d6 --- /dev/null +++ b/tests/test_write_only_fields.py @@ -0,0 +1,42 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class ExampleModel(models.Model): +    email = models.EmailField(max_length=100) +    password = models.CharField(max_length=100) + + +class WriteOnlyFieldTests(TestCase): +    def test_write_only_fields(self): +        class ExampleSerializer(serializers.Serializer): +            email = serializers.EmailField() +            password = serializers.CharField(write_only=True) + +        data = { +            'email': 'foo@example.com', +            'password': '123' +        } +        serializer = ExampleSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.object, data) +        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) + +    def test_write_only_fields_meta(self): +        class ExampleSerializer(serializers.ModelSerializer): +            class Meta: +                model = ExampleModel +                fields = ('email', 'password') +                write_only_fields = ('password',) + +        data = { +            'email': 'foo@example.com', +            'password': '123' +        } +        serializer = ExampleSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertTrue(isinstance(serializer.object, ExampleModel)) +        self.assertEquals(serializer.object.email, data['email']) +        self.assertEquals(serializer.object.password, data['password']) +        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) diff --git a/tests/urls.py b/tests/urls.py new file mode 100644 index 00000000..62cad339 --- /dev/null +++ b/tests/urls.py @@ -0,0 +1,6 @@ +""" +Blank URLConf just to keep the test suite happy +""" +from rest_framework.compat import patterns + +urlpatterns = patterns('') diff --git a/tests/users/__init__.py b/tests/users/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/tests/users/__init__.py diff --git a/tests/users/models.py b/tests/users/models.py new file mode 100644 index 00000000..128bac90 --- /dev/null +++ b/tests/users/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class User(models.Model): +    account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users') +    active_record = models.ForeignKey('records.Record', blank=True, null=True) diff --git a/tests/users/serializers.py b/tests/users/serializers.py new file mode 100644 index 00000000..4893ddb3 --- /dev/null +++ b/tests/users/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from tests.users.models import User + + +class UserSerializer(serializers.ModelSerializer): +    class Meta: +        model = User diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..a8f2eb0b --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,25 @@ +from contextlib import contextmanager +from rest_framework.compat import six +from rest_framework.settings import api_settings + + +@contextmanager +def temporary_setting(setting, value, module=None): +    """ +    Temporarily change value of setting for test. + +    Optionally reload given module, useful when module uses value of setting on +    import. +    """ +    original_value = getattr(api_settings, setting) +    setattr(api_settings, setting, value) + +    if module is not None: +        six.moves.reload_module(module) + +    yield + +    setattr(api_settings, setting, original_value) + +    if module is not None: +        six.moves.reload_module(module) diff --git a/tests/views.py b/tests/views.py new file mode 100644 index 00000000..55935e92 --- /dev/null +++ b/tests/views.py @@ -0,0 +1,8 @@ +from rest_framework import generics +from .models import NullableForeignKeySource +from .serializers import NullableFKSourceSerializer + + +class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): +    model = NullableForeignKeySource +    model_serializer_class = NullableFKSourceSerializer | 
