diff options
Diffstat (limited to 'tests')
59 files changed, 3365 insertions, 5641 deletions
diff --git a/tests/accounts/__init__.py b/tests/accounts/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/tests/accounts/__init__.py +++ /dev/null diff --git a/tests/accounts/models.py b/tests/accounts/models.py deleted file mode 100644 index 3bf4a0c3..00000000 --- a/tests/accounts/models.py +++ /dev/null @@ -1,8 +0,0 @@ -from django.db import models - -from tests.users.models import User - - -class Account(models.Model): - owner = models.ForeignKey(User, related_name='accounts_owned') - admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered') diff --git a/tests/accounts/serializers.py b/tests/accounts/serializers.py deleted file mode 100644 index 57a91b92..00000000 --- a/tests/accounts/serializers.py +++ /dev/null @@ -1,11 +0,0 @@ -from rest_framework import serializers - -from tests.accounts.models import Account -from tests.users.serializers import UserSerializer - - -class AccountSerializer(serializers.ModelSerializer): - admins = UserSerializer(many=True) - - class Meta: - model = Account diff --git a/tests/conftest.py b/tests/conftest.py index 4b33e19c..31142eaf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,9 +33,6 @@ def pytest_configure(): 'rest_framework', 'rest_framework.authtoken', 'tests', - 'tests.accounts', - 'tests.records', - 'tests.users', ), PASSWORD_HASHERS=( 'django.contrib.auth.hashers.SHA1PasswordHasher', diff --git a/tests/extras/__init__.py b/tests/extras/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/tests/extras/__init__.py +++ /dev/null diff --git a/tests/extras/bad_import.py b/tests/extras/bad_import.py deleted file mode 100644 index 68263d94..00000000 --- a/tests/extras/bad_import.py +++ /dev/null @@ -1 +0,0 @@ -raise ValueError diff --git a/tests/models.py b/tests/models.py index fe064b46..456b0a0b 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,65 +1,22 @@ from __future__ import unicode_literals from django.db import models from django.utils.translation import ugettext_lazy as _ -from rest_framework import serializers - - -def foobar(): - return 'foobar' - - -class CustomField(models.CharField): - - def __init__(self, *args, **kwargs): - kwargs['max_length'] = 12 - super(CustomField, self).__init__(*args, **kwargs) class RESTFrameworkModel(models.Model): """ Base for test models that sets app_label, so they play nicely. """ + class Meta: app_label = 'tests' abstract = True -class HasPositiveIntegerAsChoice(RESTFrameworkModel): - some_choices = ((1, 'A'), (2, 'B'), (3, 'C')) - some_integer = models.PositiveIntegerField(choices=some_choices) - - -class Anchor(RESTFrameworkModel): - text = models.CharField(max_length=100, default='anchor') - - class BasicModel(RESTFrameworkModel): text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description.")) -class SlugBasedModel(RESTFrameworkModel): - text = models.CharField(max_length=100) - slug = models.SlugField(max_length=32) - - -class DefaultValueModel(RESTFrameworkModel): - text = models.CharField(default='foobar', max_length=100) - extra = models.CharField(blank=True, null=True, max_length=100) - - -class CallableDefaultValueModel(RESTFrameworkModel): - text = models.CharField(default=foobar, max_length=100) - - -class ManyToManyModel(RESTFrameworkModel): - rel = models.ManyToManyField(Anchor, help_text='Some help text.') - - -class ReadOnlyManyToManyModel(RESTFrameworkModel): - text = models.CharField(max_length=100, default='anchor') - rel = models.ManyToManyField(Anchor) - - class BaseFilterableItem(RESTFrameworkModel): text = models.CharField(max_length=100) @@ -72,73 +29,6 @@ class FilterableItem(BaseFilterableItem): date = models.DateField() -# Model for regression test for #285 - -class Comment(RESTFrameworkModel): - email = models.EmailField() - content = models.CharField(max_length=200) - created = models.DateTimeField(auto_now_add=True) - - -class ActionItem(RESTFrameworkModel): - title = models.CharField(max_length=200) - started = models.NullBooleanField(default=False) - done = models.BooleanField(default=False) - info = CustomField(default='---', max_length=12) - - -# Models for reverse relations -class Person(RESTFrameworkModel): - name = models.CharField(max_length=10) - age = models.IntegerField(null=True, blank=True) - - @property - def info(self): - return { - 'name': self.name, - 'age': self.age, - } - - -class BlogPost(RESTFrameworkModel): - title = models.CharField(max_length=100) - writer = models.ForeignKey(Person, null=True, blank=True) - - def get_first_comment(self): - return self.blogpostcomment_set.all()[0] - - -class BlogPostComment(RESTFrameworkModel): - text = models.TextField() - blog_post = models.ForeignKey(BlogPost) - - -class Album(RESTFrameworkModel): - title = models.CharField(max_length=100, unique=True) - ref = models.CharField(max_length=10, unique=True, null=True, blank=True) - - -class Photo(RESTFrameworkModel): - description = models.TextField() - album = models.ForeignKey(Album) - - -# Model for issue #324 -class BlankFieldModel(RESTFrameworkModel): - title = models.CharField(max_length=100, blank=True, null=False, - default="title") - - -# Model for issue #380 -class OptionalRelationModel(RESTFrameworkModel): - other = models.ForeignKey('OptionalRelationModel', blank=True, null=True) - - -# Model for RegexField -class Book(RESTFrameworkModel): - isbn = models.CharField(max_length=13) - - # Models for relations tests # ManyToMany class ManyToManyTarget(RESTFrameworkModel): @@ -178,9 +68,3 @@ class NullableOneToOneSource(RESTFrameworkModel): name = models.CharField(max_length=100) target = models.OneToOneField(OneToOneTarget, null=True, blank=True, related_name='nullable_source') - - -# Serializer used to test BasicModel -class BasicModelSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel diff --git a/tests/records/__init__.py b/tests/records/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/tests/records/__init__.py +++ /dev/null diff --git a/tests/records/models.py b/tests/records/models.py deleted file mode 100644 index 76954807..00000000 --- a/tests/records/models.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.db import models - - -class Record(models.Model): - account = models.ForeignKey('accounts.Account', blank=True, null=True) - owner = models.ForeignKey('users.User', blank=True, null=True) diff --git a/tests/serializers.py b/tests/serializers.py deleted file mode 100644 index be7b3772..00000000 --- a/tests/serializers.py +++ /dev/null @@ -1,7 +0,0 @@ -from rest_framework import serializers -from tests.models import NullableForeignKeySource - - -class NullableFKSourceSerializer(serializers.ModelSerializer): - class Meta: - model = NullableForeignKeySource diff --git a/tests/settings.py b/tests/settings.py deleted file mode 100644 index 91c9ed09..00000000 --- a/tests/settings.py +++ /dev/null @@ -1,165 +0,0 @@ -# Django settings for testproject project. - -DEBUG = True -TEMPLATE_DEBUG = DEBUG -DEBUG_PROPAGATE_EXCEPTIONS = True - -ALLOWED_HOSTS = ['*'] - -ADMINS = ( - # ('Your Name', 'your_email@domain.com'), -) - -MANAGERS = ADMINS - -DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. - 'NAME': 'sqlite.db', # Or path to database file if using sqlite3. - 'USER': '', # Not used with sqlite3. - 'PASSWORD': '', # Not used with sqlite3. - 'HOST': '', # Set to empty string for localhost. Not used with sqlite3. - 'PORT': '', # Set to empty string for default. Not used with sqlite3. - } -} - -CACHES = { - 'default': { - 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', - } -} - -# Local time zone for this installation. Choices can be found here: -# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name -# although not all choices may be available on all operating systems. -# On Unix systems, a value of None will cause Django to use the same -# timezone as the operating system. -# If running in a Windows environment this must be set to the same as your -# system time zone. -TIME_ZONE = 'Europe/London' - -# Language code for this installation. All choices can be found here: -# http://www.i18nguy.com/unicode/language-identifiers.html -LANGUAGE_CODE = 'en-uk' - -SITE_ID = 1 - -# If you set this to False, Django will make some optimizations so as not -# to load the internationalization machinery. -USE_I18N = True - -# If you set this to False, Django will not format dates, numbers and -# calendars according to the current locale -USE_L10N = True - -# Absolute filesystem path to the directory that will hold user-uploaded files. -# Example: "/home/media/media.lawrence.com/" -MEDIA_ROOT = '' - -# URL that handles the media served from MEDIA_ROOT. Make sure to use a -# trailing slash if there is a path component (optional in other cases). -# Examples: "http://media.lawrence.com", "http://example.com/media/" -MEDIA_URL = '' - -# Make this unique, and don't share it with anybody. -SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy' - -# List of callables that know how to import templates from various sources. -TEMPLATE_LOADERS = ( - 'django.template.loaders.filesystem.Loader', - 'django.template.loaders.app_directories.Loader', -) - -MIDDLEWARE_CLASSES = ( - 'django.middleware.common.CommonMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', -) - -ROOT_URLCONF = 'tests.urls' - -TEMPLATE_DIRS = ( - # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates". - # Always use forward slashes, even on Windows. - # Don't forget to use absolute paths, not relative paths. -) - -INSTALLED_APPS = ( - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.sites', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'rest_framework', - 'rest_framework.authtoken', - 'tests', - 'tests.accounts', - 'tests.records', - 'tests.users', -) - -# OAuth is optional and won't work if there is no oauth_provider & oauth2 -try: - import oauth_provider # NOQA - import oauth2 # NOQA -except ImportError: - pass -else: - INSTALLED_APPS += ( - 'oauth_provider', - ) - -try: - import provider # NOQA -except ImportError: - pass -else: - INSTALLED_APPS += ( - 'provider', - 'provider.oauth2', - ) - -# guardian is optional -try: - import guardian # NOQA -except ImportError: - pass -else: - ANONYMOUS_USER_ID = -1 - AUTHENTICATION_BACKENDS = ( - 'django.contrib.auth.backends.ModelBackend', # default - 'guardian.backends.ObjectPermissionBackend', - ) - INSTALLED_APPS += ( - 'guardian', - ) - -STATIC_URL = '/static/' - -PASSWORD_HASHERS = ( - 'django.contrib.auth.hashers.SHA1PasswordHasher', - 'django.contrib.auth.hashers.PBKDF2PasswordHasher', - 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', - 'django.contrib.auth.hashers.BCryptPasswordHasher', - 'django.contrib.auth.hashers.MD5PasswordHasher', - 'django.contrib.auth.hashers.CryptPasswordHasher', -) - -AUTH_USER_MODEL = 'auth.User' - -import django - -if django.VERSION < (1, 3): - INSTALLED_APPS += ('staticfiles',) - - -# If we're running on the Jenkins server we want to archive the coverage reports as XML. -import os -if os.environ.get('HUDSON_URL', None): - TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner' - TEST_OUTPUT_VERBOSE = True - TEST_OUTPUT_DESCRIPTIONS = True - TEST_OUTPUT_DIR = 'xmlrunner' diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 32041f9c..44837c4e 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -142,7 +142,7 @@ class SessionAuthTests(TestCase): cf. [#1810](https://github.com/tomchristie/django-rest-framework/pull/1810) """ response = self.csrf_client.get('/auth/login/') - self.assertContains(response, '<Label class="span4">Username:</label>') + self.assertContains(response, '<label for="id_username">Username:</label>') def test_post_form_session_auth_failing_csrf(self): """ diff --git a/tests/test_bound_fields.py b/tests/test_bound_fields.py new file mode 100644 index 00000000..bfc54b23 --- /dev/null +++ b/tests/test_bound_fields.py @@ -0,0 +1,69 @@ +from rest_framework import serializers + + +class TestSimpleBoundField: + def test_empty_bound_field(self): + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + serializer = ExampleSerializer() + + assert serializer['text'].value == '' + assert serializer['text'].errors is None + assert serializer['text'].name == 'text' + assert serializer['amount'].value is None + assert serializer['amount'].errors is None + assert serializer['amount'].name == 'amount' + + def test_populated_bound_field(self): + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + serializer = ExampleSerializer(data={'text': 'abc', 'amount': 123}) + assert serializer.is_valid() + assert serializer['text'].value == 'abc' + assert serializer['text'].errors is None + assert serializer['text'].name == 'text' + assert serializer['amount'].value is 123 + assert serializer['amount'].errors is None + assert serializer['amount'].name == 'amount' + + def test_error_bound_field(self): + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + serializer = ExampleSerializer(data={'text': 'x' * 1000, 'amount': 123}) + serializer.is_valid() + + assert serializer['text'].value == 'x' * 1000 + assert serializer['text'].errors == ['Ensure this field has no more than 100 characters.'] + assert serializer['text'].name == 'text' + assert serializer['amount'].value is 123 + assert serializer['amount'].errors is None + assert serializer['amount'].name == 'amount' + + +class TestNestedBoundField: + def test_nested_empty_bound_field(self): + class Nested(serializers.Serializer): + more_text = serializers.CharField(max_length=100) + amount = serializers.IntegerField() + + class ExampleSerializer(serializers.Serializer): + text = serializers.CharField(max_length=100) + nested = Nested() + + serializer = ExampleSerializer() + + assert serializer['text'].value == '' + assert serializer['text'].errors is None + assert serializer['text'].name == 'text' + assert serializer['nested']['more_text'].value == '' + assert serializer['nested']['more_text'].errors is None + assert serializer['nested']['more_text'].name == 'nested.more_text' + assert serializer['nested']['amount'].value is None + assert serializer['nested']['amount'].errors is None + assert serializer['nested']['amount'].name == 'nested.amount' diff --git a/tests/test_description.py b/tests/test_description.py index 0675d209..78ce2350 100644 --- a/tests/test_description.py +++ b/tests/test_description.py @@ -2,7 +2,8 @@ from __future__ import unicode_literals from django.test import TestCase -from rest_framework.compat import apply_markdown, smart_text +from django.utils.encoding import python_2_unicode_compatible, smart_text +from rest_framework.compat import apply_markdown from rest_framework.views import APIView from .description import ViewWithNonASCIICharactersInDocstring from .description import UTF8_TEST_DOCSTRING @@ -107,6 +108,7 @@ class TestViewNamesAndDescriptions(TestCase): """ # use a mock object instead of gettext_lazy to ensure that we can't end # up with a test case string in our l10n catalog + @python_2_unicode_compatible class MockLazyStr(object): def __init__(self, string): self.s = string @@ -114,9 +116,6 @@ class TestViewNamesAndDescriptions(TestCase): def __str__(self): return self.s - def __unicode__(self): - return self.s - class MockView(APIView): __doc__ = MockLazyStr("a gettext string") diff --git a/tests/test_fields.py b/tests/test_fields.py index 0ddbe48b..04c721d3 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,1034 +1,1085 @@ -""" -General serializer field tests. -""" -from __future__ import unicode_literals - -import datetime -import re from decimal import Decimal -from uuid import uuid4 -from django.core import validators -from django.db import models -from django.test import TestCase -from django.utils.datastructures import SortedDict +from django.utils import timezone from rest_framework import serializers -from tests.models import RESTFrameworkModel - - -class TimestampedModel(models.Model): - added = models.DateTimeField(auto_now_add=True) - updated = models.DateTimeField(auto_now=True) - - -class CharPrimaryKeyModel(models.Model): - id = models.CharField(max_length=20, primary_key=True) - - -class TimestampedModelSerializer(serializers.ModelSerializer): - class Meta: - model = TimestampedModel - - -class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): - class Meta: - model = CharPrimaryKeyModel - - -class TimeFieldModel(models.Model): - clock = models.TimeField() - - -class TimeFieldModelSerializer(serializers.ModelSerializer): - class Meta: - model = TimeFieldModel - - -SAMPLE_CHOICES = [ - ('red', 'Red'), - ('green', 'Green'), - ('blue', 'Blue'), -] - - -class ChoiceFieldModel(models.Model): - choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255) - - -class ChoiceFieldModelSerializer(serializers.ModelSerializer): - class Meta: - model = ChoiceFieldModel - - -class ChoiceFieldModelWithNull(models.Model): - choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255) - - -class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer): - class Meta: - model = ChoiceFieldModelWithNull - - -class BasicFieldTests(TestCase): - def test_auto_now_fields_read_only(self): - """ - auto_now and auto_now_add fields should be read_only by default. - """ - serializer = TimestampedModelSerializer() - self.assertEqual(serializer.fields['added'].read_only, True) +import datetime +import django +import pytest - def test_auto_pk_fields_read_only(self): - """ - AutoField fields should be read_only by default. - """ - serializer = TimestampedModelSerializer() - self.assertEqual(serializer.fields['id'].read_only, True) - def test_non_auto_pk_fields_not_read_only(self): - """ - PK fields other than AutoField fields should not be read_only by default. - """ - serializer = CharPrimaryKeyModelSerializer() - self.assertEqual(serializer.fields['id'].read_only, False) +# Tests for field keyword arguments and core functionality. +# --------------------------------------------------------- - def test_dict_field_ordering(self): - """ - Field should preserve dictionary ordering, if it exists. - See: https://github.com/tomchristie/django-rest-framework/issues/832 - """ - ret = SortedDict() - ret['c'] = 1 - ret['b'] = 1 - ret['a'] = 1 - ret['z'] = 1 - field = serializers.Field() - keys = list(field.to_native(ret).keys()) - self.assertEqual(keys, ['c', 'b', 'a', 'z']) - - def test_widget_html_attributes(self): - """ - Make sure widget_html() renders the correct attributes - """ - r = re.compile('(\S+)=["\']?((?:.(?!["\']?\s+(?:\S+)=|[>"\']))+.)["\']?') - form = TimeFieldModelSerializer().data - attributes = r.findall(form.fields['clock'].widget_html()) - self.assertIn(('name', 'clock'), attributes) - self.assertIn(('id', 'clock'), attributes) - - -class DateFieldTest(TestCase): +class TestEmpty: """ - Tests for the DateFieldTest from_native() and to_native() behavior + Tests for `required`, `allow_null`, `allow_blank`, `default`. """ - - def test_from_native_string(self): - """ - Make sure from_native() accepts default iso input formats. - """ - f = serializers.DateField() - result_1 = f.from_native('1984-07-31') - - self.assertEqual(datetime.date(1984, 7, 31), result_1) - - def test_from_native_datetime_date(self): - """ - Make sure from_native() accepts a datetime.date instance. - """ - f = serializers.DateField() - result_1 = f.from_native(datetime.date(1984, 7, 31)) - - self.assertEqual(result_1, datetime.date(1984, 7, 31)) - - def test_from_native_custom_format(self): - """ - Make sure from_native() accepts custom input formats. - """ - f = serializers.DateField(input_formats=['%Y -- %d']) - result = f.from_native('1984 -- 31') - - self.assertEqual(datetime.date(1984, 1, 31), result) - - def test_from_native_invalid_default_on_custom_format(self): - """ - Make sure from_native() don't accept default formats if custom format is preset - """ - f = serializers.DateField(input_formats=['%Y -- %d']) - - try: - f.from_native('1984-07-31') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"]) - else: - self.fail("ValidationError was not properly raised") - - def test_from_native_empty(self): + def test_required(self): """ - Make sure from_native() returns None on empty param. + By default a field must be included in the input. """ - f = serializers.DateField() - result = f.from_native('') + field = serializers.IntegerField() + with pytest.raises(serializers.ValidationError) as exc_info: + field.run_validation() + assert exc_info.value.detail == ['This field is required.'] - self.assertEqual(result, None) - - def test_from_native_none(self): + def test_not_required(self): """ - Make sure from_native() returns None on None param. + If `required=False` then a field may be omitted from the input. """ - f = serializers.DateField() - result = f.from_native(None) - - self.assertEqual(result, None) + field = serializers.IntegerField(required=False) + with pytest.raises(serializers.SkipField): + field.run_validation() - def test_from_native_invalid_date(self): + def test_disallow_null(self): """ - Make sure from_native() raises a ValidationError on passing an invalid date. + By default `None` is not a valid input. """ - f = serializers.DateField() + field = serializers.IntegerField() + with pytest.raises(serializers.ValidationError) as exc_info: + field.run_validation(None) + assert exc_info.value.detail == ['This field may not be null.'] - try: - f.from_native('1984-13-31') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) - else: - self.fail("ValidationError was not properly raised") - - def test_from_native_invalid_format(self): + def test_allow_null(self): """ - Make sure from_native() raises a ValidationError on passing an invalid format. + If `allow_null=True` then `None` is a valid input. """ - f = serializers.DateField() - - try: - f.from_native('1984 -- 31') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"]) - else: - self.fail("ValidationError was not properly raised") + field = serializers.IntegerField(allow_null=True) + output = field.run_validation(None) + assert output is None - def test_to_native(self): + def test_disallow_blank(self): """ - Make sure to_native() returns datetime as default. + By default '' is not a valid input. """ - f = serializers.DateField() - - result_1 = f.to_native(datetime.date(1984, 7, 31)) - - self.assertEqual(datetime.date(1984, 7, 31), result_1) + field = serializers.CharField() + with pytest.raises(serializers.ValidationError) as exc_info: + field.run_validation('') + assert exc_info.value.detail == ['This field may not be blank.'] - def test_to_native_iso(self): + def test_allow_blank(self): """ - Make sure to_native() with 'iso-8601' returns iso formated date. + If `allow_blank=True` then '' is a valid input. """ - f = serializers.DateField(format='iso-8601') + field = serializers.CharField(allow_blank=True) + output = field.run_validation('') + assert output == '' - result_1 = f.to_native(datetime.date(1984, 7, 31)) - - self.assertEqual('1984-07-31', result_1) - - def test_to_native_custom_format(self): + def test_default(self): """ - Make sure to_native() returns correct custom format. + If `default` is set, then omitted values get the default input. """ - f = serializers.DateField(format="%Y - %m.%d") - - result_1 = f.to_native(datetime.date(1984, 7, 31)) - - self.assertEqual('1984 - 07.31', result_1) + field = serializers.IntegerField(default=123) + output = field.run_validation() + assert output is 123 - def test_to_native_none(self): - """ - Make sure from_native() returns None on None param. - """ - f = serializers.DateField(required=False) - self.assertEqual(None, f.to_native(None)) +class TestSource: + def test_source(self): + class ExampleSerializer(serializers.Serializer): + example_field = serializers.CharField(source='other') + serializer = ExampleSerializer(data={'example_field': 'abc'}) + assert serializer.is_valid() + assert serializer.validated_data == {'other': 'abc'} -class DateTimeFieldTest(TestCase): - """ - Tests for the DateTimeField from_native() and to_native() behavior - """ + def test_redundant_source(self): + class ExampleSerializer(serializers.Serializer): + example_field = serializers.CharField(source='example_field') + with pytest.raises(AssertionError) as exc_info: + ExampleSerializer().fields + assert str(exc_info.value) == ( + "It is redundant to specify `source='example_field'` on field " + "'CharField' in serializer 'ExampleSerializer', because it is the " + "same as the field name. Remove the `source` keyword argument." + ) - def test_from_native_string(self): - """ - Make sure from_native() accepts default iso input formats. - """ - f = serializers.DateTimeField() - result_1 = f.from_native('1984-07-31 04:31') - result_2 = f.from_native('1984-07-31 04:31:59') - result_3 = f.from_native('1984-07-31 04:31:59.000200') - self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1) - self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2) - self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3) +class TestReadOnly: + def setup(self): + class TestSerializer(serializers.Serializer): + read_only = serializers.ReadOnlyField() + writable = serializers.IntegerField() + self.Serializer = TestSerializer - def test_from_native_datetime_datetime(self): + def test_validate_read_only(self): """ - Make sure from_native() accepts a datetime.datetime instance. + Read-only serializers.should not be included in validation. """ - f = serializers.DateTimeField() - result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31)) - result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) - result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + data = {'read_only': 123, 'writable': 456} + serializer = self.Serializer(data=data) + assert serializer.is_valid() + assert serializer.validated_data == {'writable': 456} - self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31)) - self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59)) - self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - - def test_from_native_custom_format(self): + def test_serialize_read_only(self): """ - Make sure from_native() accepts custom input formats. + Read-only serializers.should be serialized. """ - f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) - result = f.from_native('1984 -- 04:59') - - self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result) - - def test_from_native_invalid_default_on_custom_format(self): - """ - Make sure from_native() don't accept default formats if custom format is preset - """ - f = serializers.DateTimeField(input_formats=['%Y -- %H:%M']) - - try: - f.from_native('1984-07-31 04:31:59') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"]) - else: - self.fail("ValidationError was not properly raised") + instance = {'read_only': 123, 'writable': 456} + serializer = self.Serializer(instance) + assert serializer.data == {'read_only': 123, 'writable': 456} - def test_from_native_empty(self): - """ - Make sure from_native() returns None on empty param. - """ - f = serializers.DateTimeField() - result = f.from_native('') - self.assertEqual(result, None) +class TestWriteOnly: + def setup(self): + class TestSerializer(serializers.Serializer): + write_only = serializers.IntegerField(write_only=True) + readable = serializers.IntegerField() + self.Serializer = TestSerializer - def test_from_native_none(self): + def test_validate_write_only(self): """ - Make sure from_native() returns None on None param. + Write-only serializers.should be included in validation. """ - f = serializers.DateTimeField() - result = f.from_native(None) - - self.assertEqual(result, None) + data = {'write_only': 123, 'readable': 456} + serializer = self.Serializer(data=data) + assert serializer.is_valid() + assert serializer.validated_data == {'write_only': 123, 'readable': 456} - def test_from_native_invalid_datetime(self): + def test_serialize_write_only(self): """ - Make sure from_native() raises a ValidationError on passing an invalid datetime. + Write-only serializers.should not be serialized. """ - f = serializers.DateTimeField() + instance = {'write_only': 123, 'readable': 456} + serializer = self.Serializer(instance) + assert serializer.data == {'readable': 456} - try: - f.from_native('04:61:59') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " - "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"]) - else: - self.fail("ValidationError was not properly raised") - def test_from_native_invalid_format(self): - """ - Make sure from_native() raises a ValidationError on passing an invalid format. - """ - f = serializers.DateTimeField() +class TestInitial: + def setup(self): + class TestSerializer(serializers.Serializer): + initial_field = serializers.IntegerField(initial=123) + blank_field = serializers.IntegerField() + self.serializer = TestSerializer() - try: - f.from_native('04 -- 31') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: " - "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"]) - else: - self.fail("ValidationError was not properly raised") - - def test_to_native(self): + def test_initial(self): """ - Make sure to_native() returns isoformat as default. + Initial values should be included when serializing a new representation. """ - f = serializers.DateTimeField() + assert self.serializer.data == { + 'initial_field': 123, + 'blank_field': None + } - result_1 = f.to_native(datetime.datetime(1984, 7, 31)) - result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) - result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) - result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - self.assertEqual(datetime.datetime(1984, 7, 31), result_1) - self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) - self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) - self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) +class TestLabel: + def setup(self): + class TestSerializer(serializers.Serializer): + labeled = serializers.IntegerField(label='My label') + self.serializer = TestSerializer() - def test_to_native_iso(self): + def test_label(self): """ - Make sure to_native() with format=iso-8601 returns iso formatted datetime. + A field's label may be set with the `label` argument. """ - f = serializers.DateTimeField(format='iso-8601') + fields = self.serializer.fields + assert fields['labeled'].label == 'My label' - result_1 = f.to_native(datetime.datetime(1984, 7, 31)) - result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) - result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) - result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - self.assertEqual('1984-07-31T00:00:00', result_1) - self.assertEqual('1984-07-31T04:31:00', result_2) - self.assertEqual('1984-07-31T04:31:59', result_3) - self.assertEqual('1984-07-31T04:31:59.000200', result_4) +class TestInvalidErrorKey: + def setup(self): + class ExampleField(serializers.Field): + def to_native(self, data): + self.fail('incorrect') + self.field = ExampleField() - def test_to_native_custom_format(self): + def test_invalid_error_key(self): """ - Make sure to_native() returns correct custom format. + If a field raises a validation error, but does not have a corresponding + error message, then raise an appropriate assertion error. """ - f = serializers.DateTimeField(format="%Y - %H:%M") + with pytest.raises(AssertionError) as exc_info: + self.field.to_native(123) + expected = ( + 'ValidationError raised by `ExampleField`, but error key ' + '`incorrect` does not exist in the `error_messages` dictionary.' + ) + assert str(exc_info.value) == expected - result_1 = f.to_native(datetime.datetime(1984, 7, 31)) - result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) - result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) - result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) - self.assertEqual('1984 - 00:00', result_1) - self.assertEqual('1984 - 04:31', result_2) - self.assertEqual('1984 - 04:31', result_3) - self.assertEqual('1984 - 04:31', result_4) +class TestBooleanHTMLInput: + def setup(self): + class TestSerializer(serializers.Serializer): + archived = serializers.BooleanField() + self.Serializer = TestSerializer - def test_to_native_none(self): + def test_empty_html_checkbox(self): """ - Make sure from_native() returns None on None param. + HTML checkboxes do not send any value, but should be treated + as `False` by BooleanField. """ - f = serializers.DateTimeField(required=False) - self.assertEqual(None, f.to_native(None)) + # This class mocks up a dictionary like object, that behaves + # as if it was returned for multipart or urlencoded data. + class MockHTMLDict(dict): + getlist = None + serializer = self.Serializer(data=MockHTMLDict()) + assert serializer.is_valid() + assert serializer.validated_data == {'archived': False} -class TimeFieldTest(TestCase): +class MockHTMLDict(dict): """ - Tests for the TimeField from_native() and to_native() behavior + This class mocks up a dictionary like object, that behaves + as if it was returned for multipart or urlencoded data. """ + getlist = None - def test_from_native_string(self): - """ - Make sure from_native() accepts default iso input formats. - """ - f = serializers.TimeField() - result_1 = f.from_native('04:31') - result_2 = f.from_native('04:31:59') - result_3 = f.from_native('04:31:59.000200') - self.assertEqual(datetime.time(4, 31), result_1) - self.assertEqual(datetime.time(4, 31, 59), result_2) - self.assertEqual(datetime.time(4, 31, 59, 200), result_3) +class TestCharHTMLInput: + def test_empty_html_checkbox(self): + class TestSerializer(serializers.Serializer): + message = serializers.CharField(default='happy') - def test_from_native_datetime_time(self): - """ - Make sure from_native() accepts a datetime.time instance. - """ - f = serializers.TimeField() - result_1 = f.from_native(datetime.time(4, 31)) - result_2 = f.from_native(datetime.time(4, 31, 59)) - result_3 = f.from_native(datetime.time(4, 31, 59, 200)) + serializer = TestSerializer(data=MockHTMLDict()) + assert serializer.is_valid() + assert serializer.validated_data == {'message': 'happy'} - self.assertEqual(result_1, datetime.time(4, 31)) - self.assertEqual(result_2, datetime.time(4, 31, 59)) - self.assertEqual(result_3, datetime.time(4, 31, 59, 200)) + def test_empty_html_checkbox_allow_null(self): + class TestSerializer(serializers.Serializer): + message = serializers.CharField(allow_null=True) - def test_from_native_custom_format(self): - """ - Make sure from_native() accepts custom input formats. - """ - f = serializers.TimeField(input_formats=['%H -- %M']) - result = f.from_native('04 -- 31') + serializer = TestSerializer(data=MockHTMLDict()) + assert serializer.is_valid() + assert serializer.validated_data == {'message': None} - self.assertEqual(datetime.time(4, 31), result) + def test_empty_html_checkbox_allow_null_allow_blank(self): + class TestSerializer(serializers.Serializer): + message = serializers.CharField(allow_null=True, allow_blank=True) - def test_from_native_invalid_default_on_custom_format(self): - """ - Make sure from_native() don't accept default formats if custom format is preset - """ - f = serializers.TimeField(input_formats=['%H -- %M']) + serializer = TestSerializer(data=MockHTMLDict({})) + assert serializer.is_valid() + assert serializer.validated_data == {'message': ''} - try: - f.from_native('04:31:59') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"]) - else: - self.fail("ValidationError was not properly raised") - - def test_from_native_empty(self): - """ - Make sure from_native() returns None on empty param. - """ - f = serializers.TimeField() - result = f.from_native('') - - self.assertEqual(result, None) - - def test_from_native_none(self): - """ - Make sure from_native() returns None on None param. - """ - f = serializers.TimeField() - result = f.from_native(None) - - self.assertEqual(result, None) - - def test_from_native_invalid_time(self): - """ - Make sure from_native() raises a ValidationError on passing an invalid time. - """ - f = serializers.TimeField() + def test_empty_html_required_false(self): + class TestSerializer(serializers.Serializer): + message = serializers.CharField(required=False) - try: - f.from_native('04:61:59') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " - "hh:mm[:ss[.uuuuuu]]"]) - else: - self.fail("ValidationError was not properly raised") + serializer = TestSerializer(data=MockHTMLDict()) + assert serializer.is_valid() + assert serializer.validated_data == {} - def test_from_native_invalid_format(self): - """ - Make sure from_native() raises a ValidationError on passing an invalid format. - """ - f = serializers.TimeField() - try: - f.from_native('04 -- 31') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: " - "hh:mm[:ss[.uuuuuu]]"]) - else: - self.fail("ValidationError was not properly raised") +class TestCreateOnlyDefault: + def setup(self): + default = serializers.CreateOnlyDefault('2001-01-01') - def test_to_native(self): - """ - Make sure to_native() returns time object as default. - """ - f = serializers.TimeField() - result_1 = f.to_native(datetime.time(4, 31)) - result_2 = f.to_native(datetime.time(4, 31, 59)) - result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + class TestSerializer(serializers.Serializer): + published = serializers.HiddenField(default=default) + text = serializers.CharField() + self.Serializer = TestSerializer - self.assertEqual(datetime.time(4, 31), result_1) - self.assertEqual(datetime.time(4, 31, 59), result_2) - self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + def test_create_only_default_is_provided(self): + serializer = self.Serializer(data={'text': 'example'}) + assert serializer.is_valid() + assert serializer.validated_data == { + 'text': 'example', 'published': '2001-01-01' + } - def test_to_native_iso(self): - """ - Make sure to_native() with format='iso-8601' returns iso formatted time. - """ - f = serializers.TimeField(format='iso-8601') - result_1 = f.to_native(datetime.time(4, 31)) - result_2 = f.to_native(datetime.time(4, 31, 59)) - result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + def test_create_only_default_is_not_provided_on_update(self): + instance = { + 'text': 'example', 'published': '2001-01-01' + } + serializer = self.Serializer(instance, data={'text': 'example'}) + assert serializer.is_valid() + assert serializer.validated_data == { + 'text': 'example', + } - self.assertEqual('04:31:00', result_1) - self.assertEqual('04:31:59', result_2) - self.assertEqual('04:31:59.000200', result_3) - def test_to_native_custom_format(self): - """ - Make sure to_native() returns correct custom format. - """ - f = serializers.TimeField(format="%H - %S [%f]") - result_1 = f.to_native(datetime.time(4, 31)) - result_2 = f.to_native(datetime.time(4, 31, 59)) - result_3 = f.to_native(datetime.time(4, 31, 59, 200)) +# Tests for field input and output values. +# ---------------------------------------- - self.assertEqual('04 - 00 [000000]', result_1) - self.assertEqual('04 - 59 [000000]', result_2) - self.assertEqual('04 - 59 [000200]', result_3) +def get_items(mapping_or_list_of_two_tuples): + # Tests accept either lists of two tuples, or dictionaries. + if isinstance(mapping_or_list_of_two_tuples, dict): + # {value: expected} + return mapping_or_list_of_two_tuples.items() + # [(value, expected), ...] + return mapping_or_list_of_two_tuples -class DecimalFieldTest(TestCase): +class FieldValues: """ - Tests for the DecimalField from_native() and to_native() behavior + Base class for testing valid and invalid input values. """ - - def test_from_native_string(self): + def test_valid_inputs(self): """ - Make sure from_native() accepts string values + Ensure that valid values return the expected validated data. """ - f = serializers.DecimalField() - result_1 = f.from_native('9000') - result_2 = f.from_native('1.00000001') - - self.assertEqual(Decimal('9000'), result_1) - self.assertEqual(Decimal('1.00000001'), result_2) + for input_value, expected_output in get_items(self.valid_inputs): + assert self.field.run_validation(input_value) == expected_output - def test_from_native_invalid_string(self): + def test_invalid_inputs(self): """ - Make sure from_native() raises ValidationError on passing invalid string + Ensure that invalid values raise the expected validation error. """ - f = serializers.DecimalField() + for input_value, expected_failure in get_items(self.invalid_inputs): + with pytest.raises(serializers.ValidationError) as exc_info: + self.field.run_validation(input_value) + assert exc_info.value.detail == expected_failure - try: - f.from_native('123.45.6') - except validators.ValidationError as e: - self.assertEqual(e.messages, ["Enter a number."]) - else: - self.fail("ValidationError was not properly raised") + def test_outputs(self): + for output_value, expected_output in get_items(self.outputs): + assert self.field.to_representation(output_value) == expected_output - def test_from_native_integer(self): - """ - Make sure from_native() accepts integer values - """ - f = serializers.DecimalField() - result = f.from_native(9000) - self.assertEqual(Decimal('9000'), result) +# Boolean types... - def test_from_native_float(self): - """ - Make sure from_native() accepts float values - """ - f = serializers.DecimalField() - result = f.from_native(1.00000001) +class TestBooleanField(FieldValues): + """ + Valid and invalid values for `BooleanField`. + """ + valid_inputs = { + 'true': True, + 'false': False, + '1': True, + '0': False, + 1: True, + 0: False, + True: True, + False: False, + } + invalid_inputs = { + 'foo': ['`foo` is not a valid boolean.'], + None: ['This field may not be null.'] + } + outputs = { + 'true': True, + 'false': False, + '1': True, + '0': False, + 1: True, + 0: False, + True: True, + False: False, + 'other': True + } + field = serializers.BooleanField() + + +class TestNullBooleanField(FieldValues): + """ + Valid and invalid values for `BooleanField`. + """ + valid_inputs = { + 'true': True, + 'false': False, + 'null': None, + True: True, + False: False, + None: None + } + invalid_inputs = { + 'foo': ['`foo` is not a valid boolean.'], + } + outputs = { + 'true': True, + 'false': False, + 'null': None, + True: True, + False: False, + None: None, + 'other': True + } + field = serializers.NullBooleanField() + + +# String types... + +class TestCharField(FieldValues): + """ + Valid and invalid values for `CharField`. + """ + valid_inputs = { + 1: '1', + 'abc': 'abc' + } + invalid_inputs = { + '': ['This field may not be blank.'] + } + outputs = { + 1: '1', + 'abc': 'abc' + } + field = serializers.CharField() + + +class TestEmailField(FieldValues): + """ + Valid and invalid values for `EmailField`. + """ + valid_inputs = { + 'example@example.com': 'example@example.com', + ' example@example.com ': 'example@example.com', + } + invalid_inputs = { + 'examplecom': ['Enter a valid email address.'] + } + outputs = {} + field = serializers.EmailField() + + +class TestRegexField(FieldValues): + """ + Valid and invalid values for `RegexField`. + """ + valid_inputs = { + 'a9': 'a9', + } + invalid_inputs = { + 'A9': ["This value does not match the required pattern."] + } + outputs = {} + field = serializers.RegexField(regex='[a-z][0-9]') - self.assertEqual(Decimal('1.00000001'), result) - def test_from_native_empty(self): - """ - Make sure from_native() returns None on empty param. - """ - f = serializers.DecimalField() - result = f.from_native('') +class TestSlugField(FieldValues): + """ + Valid and invalid values for `SlugField`. + """ + valid_inputs = { + 'slug-99': 'slug-99', + } + invalid_inputs = { + 'slug 99': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."] + } + outputs = {} + field = serializers.SlugField() - self.assertEqual(result, None) - def test_from_native_none(self): - """ - Make sure from_native() returns None on None param. - """ - f = serializers.DecimalField() - result = f.from_native(None) +class TestURLField(FieldValues): + """ + Valid and invalid values for `URLField`. + """ + valid_inputs = { + 'http://example.com': 'http://example.com', + } + invalid_inputs = { + 'example.com': ['Enter a valid URL.'] + } + outputs = {} + field = serializers.URLField() - self.assertEqual(result, None) - def test_to_native(self): - """ - Make sure to_native() returns Decimal as string. - """ - f = serializers.DecimalField() +# Number types... - result_1 = f.to_native(Decimal('9000')) - result_2 = f.to_native(Decimal('1.00000001')) - - self.assertEqual(Decimal('9000'), result_1) - self.assertEqual(Decimal('1.00000001'), result_2) +class TestIntegerField(FieldValues): + """ + Valid and invalid values for `IntegerField`. + """ + valid_inputs = { + '1': 1, + '0': 0, + 1: 1, + 0: 0, + 1.0: 1, + 0.0: 0 + } + invalid_inputs = { + 'abc': ['A valid integer is required.'] + } + outputs = { + '1': 1, + '0': 0, + 1: 1, + 0: 0, + 1.0: 1, + 0.0: 0 + } + field = serializers.IntegerField() + + +class TestMinMaxIntegerField(FieldValues): + """ + Valid and invalid values for `IntegerField` with min and max limits. + """ + valid_inputs = { + '1': 1, + '3': 3, + 1: 1, + 3: 3, + } + invalid_inputs = { + 0: ['Ensure this value is greater than or equal to 1.'], + 4: ['Ensure this value is less than or equal to 3.'], + '0': ['Ensure this value is greater than or equal to 1.'], + '4': ['Ensure this value is less than or equal to 3.'], + } + outputs = {} + field = serializers.IntegerField(min_value=1, max_value=3) + + +class TestFloatField(FieldValues): + """ + Valid and invalid values for `FloatField`. + """ + valid_inputs = { + '1': 1.0, + '0': 0.0, + 1: 1.0, + 0: 0.0, + 1.0: 1.0, + 0.0: 0.0, + } + invalid_inputs = { + 'abc': ["A valid number is required."] + } + outputs = { + '1': 1.0, + '0': 0.0, + 1: 1.0, + 0: 0.0, + 1.0: 1.0, + 0.0: 0.0, + } + field = serializers.FloatField() + + +class TestMinMaxFloatField(FieldValues): + """ + Valid and invalid values for `FloatField` with min and max limits. + """ + valid_inputs = { + '1': 1, + '3': 3, + 1: 1, + 3: 3, + 1.0: 1.0, + 3.0: 3.0, + } + invalid_inputs = { + 0.9: ['Ensure this value is greater than or equal to 1.'], + 3.1: ['Ensure this value is less than or equal to 3.'], + '0.0': ['Ensure this value is greater than or equal to 1.'], + '3.1': ['Ensure this value is less than or equal to 3.'], + } + outputs = {} + field = serializers.FloatField(min_value=1, max_value=3) + + +class TestDecimalField(FieldValues): + """ + Valid and invalid values for `DecimalField`. + """ + valid_inputs = { + '12.3': Decimal('12.3'), + '0.1': Decimal('0.1'), + 10: Decimal('10'), + 0: Decimal('0'), + 12.3: Decimal('12.3'), + 0.1: Decimal('0.1'), + } + invalid_inputs = ( + ('abc', ["A valid number is required."]), + (Decimal('Nan'), ["A valid number is required."]), + (Decimal('Inf'), ["A valid number is required."]), + ('12.345', ["Ensure that there are no more than 3 digits in total."]), + ('0.01', ["Ensure that there are no more than 1 decimal places."]), + (123, ["Ensure that there are no more than 2 digits before the decimal point."]) + ) + outputs = { + '1': '1.0', + '0': '0.0', + '1.09': '1.1', + '0.04': '0.0', + 1: '1.0', + 0: '0.0', + Decimal('1.0'): '1.0', + Decimal('0.0'): '0.0', + Decimal('1.09'): '1.1', + Decimal('0.04'): '0.0' + } + field = serializers.DecimalField(max_digits=3, decimal_places=1) + + +class TestMinMaxDecimalField(FieldValues): + """ + Valid and invalid values for `DecimalField` with min and max limits. + """ + valid_inputs = { + '10.0': Decimal('10.0'), + '20.0': Decimal('20.0'), + } + invalid_inputs = { + '9.9': ['Ensure this value is greater than or equal to 10.'], + '20.1': ['Ensure this value is less than or equal to 20.'], + } + outputs = {} + field = serializers.DecimalField( + max_digits=3, decimal_places=1, + min_value=10, max_value=20 + ) + + +class TestNoStringCoercionDecimalField(FieldValues): + """ + Output values for `DecimalField` with `coerce_to_string=False`. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + 1.09: Decimal('1.1'), + 0.04: Decimal('0.0'), + '1.09': Decimal('1.1'), + '0.04': Decimal('0.0'), + Decimal('1.09'): Decimal('1.1'), + Decimal('0.04'): Decimal('0.0'), + } + field = serializers.DecimalField( + max_digits=3, decimal_places=1, + coerce_to_string=False + ) + + +# Date & time serializers... + +class TestDateField(FieldValues): + """ + Valid and invalid values for `DateField`. + """ + valid_inputs = { + '2001-01-01': datetime.date(2001, 1, 1), + datetime.date(2001, 1, 1): datetime.date(2001, 1, 1), + } + invalid_inputs = { + 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], + '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], + datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'], + } + outputs = { + datetime.date(2001, 1, 1): '2001-01-01' + } + field = serializers.DateField() + + +class TestCustomInputFormatDateField(FieldValues): + """ + Valid and invalid values for `DateField` with a cutom input format. + """ + valid_inputs = { + '1 Jan 2001': datetime.date(2001, 1, 1), + } + invalid_inputs = { + '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY'] + } + outputs = {} + field = serializers.DateField(input_formats=['%d %b %Y']) - def test_to_native_none(self): - """ - Make sure from_native() returns None on None param. - """ - f = serializers.DecimalField(required=False) - self.assertEqual(None, f.to_native(None)) - def test_valid_serialization(self): - """ - Make sure the serializer works correctly - """ - class DecimalSerializer(serializers.Serializer): - decimal_field = serializers.DecimalField(max_value=9010, - min_value=9000, - max_digits=6, - decimal_places=2) +class TestCustomOutputFormatDateField(FieldValues): + """ + Values for `DateField` with a custom output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.date(2001, 1, 1): '01 Jan 2001' + } + field = serializers.DateField(format='%d %b %Y') - self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) - self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) - self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) - self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) - self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) - self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) +class TestNoOutputFormatDateField(FieldValues): + """ + Values for `DateField` with no output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.date(2001, 1, 1): datetime.date(2001, 1, 1) + } + field = serializers.DateField(format=None) - def test_raise_max_value(self): - """ - Make sure max_value violations raises ValidationError - """ - class DecimalSerializer(serializers.Serializer): - decimal_field = serializers.DecimalField(max_value=100) - s = DecimalSerializer(data={'decimal_field': '123'}) +class TestDateTimeField(FieldValues): + """ + Valid and invalid values for `DateTimeField`. + """ + valid_inputs = { + '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + '2001-01-01T13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()), + # Django 1.4 does not support timezone string parsing. + '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()) + } + invalid_inputs = { + 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], + '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], + datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'], + } + outputs = { + datetime.datetime(2001, 1, 1, 13, 00): '2001-01-01T13:00:00', + datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): '2001-01-01T13:00:00Z' + } + field = serializers.DateTimeField(default_timezone=timezone.UTC()) + + +class TestCustomInputFormatDateTimeField(FieldValues): + """ + Valid and invalid values for `DateTimeField` with a cutom input format. + """ + valid_inputs = { + '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()), + } + invalid_inputs = { + '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY'] + } + outputs = {} + field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y']) - self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']}) - def test_raise_min_value(self): - """ - Make sure min_value violations raises ValidationError - """ - class DecimalSerializer(serializers.Serializer): - decimal_field = serializers.DecimalField(min_value=100) +class TestCustomOutputFormatDateTimeField(FieldValues): + """ + Values for `DateTimeField` with a custom output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.datetime(2001, 1, 1, 13, 00): '01:00PM, 01 Jan 2001', + } + field = serializers.DateTimeField(format='%I:%M%p, %d %b %Y') - s = DecimalSerializer(data={'decimal_field': '99'}) - self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) +class TestNoOutputFormatDateTimeField(FieldValues): + """ + Values for `DateTimeField` with no output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00), + } + field = serializers.DateTimeField(format=None) - def test_raise_max_digits(self): - """ - Make sure max_digits violations raises ValidationError - """ - class DecimalSerializer(serializers.Serializer): - decimal_field = serializers.DecimalField(max_digits=5) - s = DecimalSerializer(data={'decimal_field': '123.456'}) +class TestNaiveDateTimeField(FieldValues): + """ + Valid and invalid values for `DateTimeField` with naive datetimes. + """ + valid_inputs = { + datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00), + '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00), + } + invalid_inputs = {} + outputs = {} + field = serializers.DateTimeField(default_timezone=None) - self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) - def test_raise_max_decimal_places(self): - """ - Make sure max_decimal_places violations raises ValidationError - """ - class DecimalSerializer(serializers.Serializer): - decimal_field = serializers.DecimalField(decimal_places=3) +class TestTimeField(FieldValues): + """ + Valid and invalid values for `TimeField`. + """ + valid_inputs = { + '13:00': datetime.time(13, 00), + datetime.time(13, 00): datetime.time(13, 00), + } + invalid_inputs = { + 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], + '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], + } + outputs = { + datetime.time(13, 00): '13:00:00' + } + field = serializers.TimeField() + + +class TestCustomInputFormatTimeField(FieldValues): + """ + Valid and invalid values for `TimeField` with a custom input format. + """ + valid_inputs = { + '1:00pm': datetime.time(13, 00), + } + invalid_inputs = { + '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]'], + } + outputs = {} + field = serializers.TimeField(input_formats=['%I:%M%p']) - s = DecimalSerializer(data={'decimal_field': '123.4567'}) - self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) +class TestCustomOutputFormatTimeField(FieldValues): + """ + Values for `TimeField` with a custom output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.time(13, 00): '01:00PM' + } + field = serializers.TimeField(format='%I:%M%p') - def test_raise_max_whole_digits(self): - """ - Make sure max_whole_digits violations raises ValidationError - """ - class DecimalSerializer(serializers.Serializer): - decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) - s = DecimalSerializer(data={'decimal_field': '12345.6'}) +class TestNoOutputFormatTimeField(FieldValues): + """ + Values for `TimeField` with a no output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.time(13, 00): datetime.time(13, 00) + } + field = serializers.TimeField(format=None) - self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) +# Choice types... -class ChoiceFieldTests(TestCase): +class TestChoiceField(FieldValues): """ - Tests for the ChoiceField options generator + Valid and invalid values for `ChoiceField`. """ - def test_choices_required(self): - """ - Make sure proper choices are rendered if field is required - """ - f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES) - self.assertEqual(f.choices, SAMPLE_CHOICES) - - def test_choices_not_required(self): - """ - Make sure proper choices (plus blank) are rendered if the field isn't required - """ - f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) - self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES) - - def test_blank_choice_display(self): - blank = 'No Preference' - f = serializers.ChoiceField( - required=False, - choices=SAMPLE_CHOICES, - blank_display_value=blank, + valid_inputs = { + 'poor': 'poor', + 'medium': 'medium', + 'good': 'good', + } + invalid_inputs = { + 'amazing': ['`amazing` is not a valid choice.'] + } + outputs = { + 'good': 'good', + '': '' + } + field = serializers.ChoiceField( + choices=[ + ('poor', 'Poor quality'), + ('medium', 'Medium quality'), + ('good', 'Good quality'), + ] + ) + + def test_allow_blank(self): + """ + If `allow_blank=True` then '' is a valid input. + """ + field = serializers.ChoiceField( + allow_blank=True, + choices=[ + ('poor', 'Poor quality'), + ('medium', 'Medium quality'), + ('good', 'Good quality'), + ] ) - self.assertEqual(f.choices, [('', blank)] + SAMPLE_CHOICES) - - def test_invalid_choice_model(self): - s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'}) - self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']}) - self.assertEqual(s.data['choice'], '') - - def test_empty_choice_model(self): - """ - Test that the 'empty' value is correctly passed and used depending on - the 'null' property on the model field. - """ - s = ChoiceFieldModelSerializer(data={'choice': ''}) - self.assertTrue(s.is_valid()) - self.assertEqual(s.data['choice'], '') - - s = ChoiceFieldModelWithNullSerializer(data={'choice': ''}) - self.assertTrue(s.is_valid()) - self.assertEqual(s.data['choice'], None) - - def test_from_native_empty(self): - """ - Make sure from_native() returns an empty string on empty param by default. - """ - f = serializers.ChoiceField(choices=SAMPLE_CHOICES) - self.assertEqual(f.from_native(''), '') - self.assertEqual(f.from_native(None), '') - - def test_from_native_empty_override(self): - """ - Make sure you can override from_native() behavior regarding empty values. - """ - f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None) - self.assertEqual(f.from_native(''), None) - self.assertEqual(f.from_native(None), None) + output = field.run_validation('') + assert output == '' - def test_metadata_choices(self): - """ - Make sure proper choices are included in the field's metadata. - """ - choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES] - f = serializers.ChoiceField(choices=SAMPLE_CHOICES) - self.assertEqual(f.metadata()['choices'], choices) - def test_metadata_choices_not_required(self): - """ - Make sure proper choices are included in the field's metadata. - """ - choices = [{'value': v, 'display_name': n} - for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES] - f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) - self.assertEqual(f.metadata()['choices'], choices) - - -class EmailFieldTests(TestCase): +class TestChoiceFieldWithType(FieldValues): """ - Tests for EmailField attribute values + Valid and invalid values for a `Choice` field that uses an integer type, + instead of a char type. """ - - class EmailFieldModel(RESTFrameworkModel): - email_field = models.EmailField(blank=True) - - class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel): - email_field = models.EmailField(max_length=150, blank=True) - - def test_default_model_value(self): - class EmailFieldSerializer(serializers.ModelSerializer): - class Meta: - model = self.EmailFieldModel - - serializer = EmailFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75) - - def test_given_model_value(self): - class EmailFieldSerializer(serializers.ModelSerializer): - class Meta: - model = self.EmailFieldWithGivenMaxLengthModel - - serializer = EmailFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150) - - def test_given_serializer_value(self): - class EmailFieldSerializer(serializers.ModelSerializer): - email_field = serializers.EmailField(source='email_field', max_length=20, required=False) - - class Meta: - model = self.EmailFieldModel - - serializer = EmailFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20) - - -class SlugFieldTests(TestCase): + valid_inputs = { + '1': 1, + 3: 3, + } + invalid_inputs = { + 5: ['`5` is not a valid choice.'], + 'abc': ['`abc` is not a valid choice.'] + } + outputs = { + '1': 1, + 1: 1 + } + field = serializers.ChoiceField( + choices=[ + (1, 'Poor quality'), + (2, 'Medium quality'), + (3, 'Good quality'), + ] + ) + + +class TestChoiceFieldWithListChoices(FieldValues): """ - Tests for SlugField attribute values + Valid and invalid values for a `Choice` field that uses a flat list for the + choices, rather than a list of pairs of (`value`, `description`). """ - - class SlugFieldModel(RESTFrameworkModel): - slug_field = models.SlugField(blank=True) - - class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel): - slug_field = models.SlugField(max_length=84, blank=True) - - def test_default_model_value(self): - class SlugFieldSerializer(serializers.ModelSerializer): - class Meta: - model = self.SlugFieldModel - - serializer = SlugFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50) - - def test_given_model_value(self): - class SlugFieldSerializer(serializers.ModelSerializer): - class Meta: - model = self.SlugFieldWithGivenMaxLengthModel - - serializer = SlugFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84) - - def test_given_serializer_value(self): - class SlugFieldSerializer(serializers.ModelSerializer): - slug_field = serializers.SlugField(source='slug_field', - max_length=20, required=False) - - class Meta: - model = self.SlugFieldModel - - serializer = SlugFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['slug_field'], - 'max_length'), 20) - - def test_invalid_slug(self): - """ - Make sure an invalid slug raises ValidationError - """ - class SlugFieldSerializer(serializers.ModelSerializer): - slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True) - - class Meta: - model = self.SlugFieldModel - - s = SlugFieldSerializer(data={'slug_field': 'a b'}) - - self.assertEqual(s.is_valid(), False) - self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]}) - - -class URLFieldTests(TestCase): + valid_inputs = { + 'poor': 'poor', + 'medium': 'medium', + 'good': 'good', + } + invalid_inputs = { + 'awful': ['`awful` is not a valid choice.'] + } + outputs = { + 'good': 'good' + } + field = serializers.ChoiceField(choices=('poor', 'medium', 'good')) + + +class TestMultipleChoiceField(FieldValues): """ - Tests for URLField attribute values. - - (Includes test for #1210, checking that validators can be overridden.) + Valid and invalid values for `MultipleChoiceField`. """ - - class URLFieldModel(RESTFrameworkModel): - url_field = models.URLField(blank=True) - - class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel): - url_field = models.URLField(max_length=128, blank=True) - - def test_default_model_value(self): - class URLFieldSerializer(serializers.ModelSerializer): - class Meta: - model = self.URLFieldModel - - serializer = URLFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['url_field'], - 'max_length'), 200) - - def test_given_model_value(self): - class URLFieldSerializer(serializers.ModelSerializer): - class Meta: - model = self.URLFieldWithGivenMaxLengthModel - - serializer = URLFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['url_field'], - 'max_length'), 128) - - def test_given_serializer_value(self): - class URLFieldSerializer(serializers.ModelSerializer): - url_field = serializers.URLField(source='url_field', - max_length=20, required=False) - - class Meta: - model = self.URLFieldWithGivenMaxLengthModel - - serializer = URLFieldSerializer(data={}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(getattr(serializer.fields['url_field'], - 'max_length'), 20) - - def test_validators_can_be_overridden(self): - url_field = serializers.URLField(validators=[]) - validators = url_field.validators - self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators') - - -class FieldMetadata(TestCase): - def setUp(self): - self.required_field = serializers.Field() - self.required_field.label = uuid4().hex - self.required_field.required = True - - self.optional_field = serializers.Field() - self.optional_field.label = uuid4().hex - self.optional_field.required = False - - def test_required(self): - self.assertEqual(self.required_field.metadata()['required'], True) - - def test_optional(self): - self.assertEqual(self.optional_field.metadata()['required'], False) - - def test_label(self): - for field in (self.required_field, self.optional_field): - self.assertEqual(field.metadata()['label'], field.label) - - -class FieldCallableDefault(TestCase): - def setUp(self): - self.simple_callable = lambda: 'foo bar' - - def test_default_can_be_simple_callable(self): - """ - Ensure that the 'default' argument can also be a simple callable. - """ - field = serializers.WritableField(default=self.simple_callable) - into = {} - field.field_from_native({}, {}, 'field', into) - self.assertEqual(into, {'field': 'foo bar'}) + valid_inputs = { + (): set(), + ('aircon',): set(['aircon']), + ('aircon', 'manual'): set(['aircon', 'manual']), + } + invalid_inputs = { + 'abc': ['Expected a list of items but got type `str`.'], + ('aircon', 'incorrect'): ['`incorrect` is not a valid choice.'] + } + outputs = [ + (['aircon', 'manual'], set(['aircon', 'manual'])) + ] + field = serializers.MultipleChoiceField( + choices=[ + ('aircon', 'AirCon'), + ('manual', 'Manual drive'), + ('diesel', 'Diesel'), + ] + ) + + +# File serializers... + +class MockFile: + def __init__(self, name='', size=0, url=''): + self.name = name + self.size = size + self.url = url + + def __eq__(self, other): + return ( + isinstance(other, MockFile) and + self.name == other.name and + self.size == other.size and + self.url == other.url + ) -class CustomIntegerField(TestCase): +class TestFileField(FieldValues): """ - Test that custom fields apply min_value and max_value constraints + Values for `FileField`. """ - def test_custom_fields_can_be_validated_for_value(self): - - class MoneyField(models.PositiveIntegerField): - pass - - class EntryModel(models.Model): - bank = MoneyField(validators=[validators.MaxValueValidator(100)]) - - class EntrySerializer(serializers.ModelSerializer): - class Meta: - model = EntryModel + valid_inputs = [ + (MockFile(name='example', size=10), MockFile(name='example', size=10)) + ] + invalid_inputs = [ + ('invalid', ['The submitted data was not a file. Check the encoding type on the form.']), + (MockFile(name='example.txt', size=0), ['The submitted file is empty.']), + (MockFile(name='', size=10), ['No filename could be determined.']), + (MockFile(name='x' * 100, size=10), ['Ensure this filename has at most 10 characters (it has 100).']) + ] + outputs = [ + (MockFile(name='example.txt', url='/example.txt'), '/example.txt'), + ('', None) + ] + field = serializers.FileField(max_length=10) + + +class TestFieldFieldWithName(FieldValues): + """ + Values for `FileField` with a filename output instead of URLs. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = [ + (MockFile(name='example.txt', url='/example.txt'), 'example.txt') + ] + field = serializers.FileField(use_url=False) - entry = EntryModel(bank=1) - serializer = EntrySerializer(entry, data={"bank": 11}) - self.assertTrue(serializer.is_valid()) +# Stub out mock Django `forms.ImageField` class so we don't *actually* +# call into it's regular validation, or require PIL for testing. +class FailImageValidation(object): + def to_python(self, value): + raise serializers.ValidationError(self.error_messages['invalid_image']) - serializer = EntrySerializer(entry, data={"bank": -1}) - self.assertFalse(serializer.is_valid()) - serializer = EntrySerializer(entry, data={"bank": 101}) - self.assertFalse(serializer.is_valid()) +class PassImageValidation(object): + def to_python(self, value): + return value -class BooleanField(TestCase): +class TestInvalidImageField(FieldValues): """ - Tests for BooleanField + Values for an invalid `ImageField`. """ - def test_boolean_required(self): - class BooleanRequiredSerializer(serializers.Serializer): - bool_field = serializers.BooleanField(required=True) - - self.assertFalse(BooleanRequiredSerializer(data={}).is_valid()) + valid_inputs = {} + invalid_inputs = [ + (MockFile(name='example.txt', size=10), ['Upload a valid image. The file you uploaded was either not an image or a corrupted image.']) + ] + outputs = {} + field = serializers.ImageField(_DjangoImageField=FailImageValidation) -class ModelCharField(TestCase): +class TestValidImageField(FieldValues): """ - Tests for CharField + Values for an valid `ImageField`. """ - def test_none_serializing(self): - class CharFieldSerializer(serializers.Serializer): - char = serializers.CharField(allow_none=True, required=False) - serializer = CharFieldSerializer(data={'char': None}) - self.assertTrue(serializer.is_valid()) - self.assertIsNone(serializer.object['char']) + valid_inputs = [ + (MockFile(name='example.txt', size=10), MockFile(name='example.txt', size=10)) + ] + invalid_inputs = {} + outputs = {} + field = serializers.ImageField(_DjangoImageField=PassImageValidation) -class SerializerMethodFieldTest(TestCase): +# Composite serializers... + +class TestListField(FieldValues): """ - Tests for the SerializerMethodField field_to_native() behavior + Values for `ListField`. """ - class SerializerTest(serializers.Serializer): - def get_my_test(self, obj): - return obj.my_test[0:5] - - class Example(): - my_test = 'Hey, this is a test !' - - def test_field_to_native(self): - s = serializers.SerializerMethodField('get_my_test') - s.initialize(self.SerializerTest(), 'name') - result = s.field_to_native(self.Example(), None) - self.assertEqual(result, 'Hey, ') + valid_inputs = [ + ([1, 2, 3], [1, 2, 3]), + (['1', '2', '3'], [1, 2, 3]) + ] + invalid_inputs = [ + ('not a list', ['Expected a list of items but got type `str`']), + ([1, 2, 'error'], ['A valid integer is required.']) + ] + outputs = [ + ([1, 2, 3], [1, 2, 3]), + (['1', '2', '3'], [1, 2, 3]) + ] + field = serializers.ListField(child=serializers.IntegerField()) + + +# Tests for FieldField. +# --------------------- + +class MockRequest: + def build_absolute_uri(self, value): + return 'http://example.com' + value + + +class TestFileFieldContext: + def test_fully_qualified_when_request_in_context(self): + field = serializers.FileField(max_length=10) + field._context = {'request': MockRequest()} + obj = MockFile(name='example.txt', url='/example.txt') + value = field.to_representation(obj) + assert value == 'http://example.com/example.txt' + + +# Tests for SerializerMethodField. +# -------------------------------- + +class TestSerializerMethodField: + def test_serializer_method_field(self): + class ExampleSerializer(serializers.Serializer): + example_field = serializers.SerializerMethodField() + + def get_example_field(self, obj): + return 'ran get_example_field(%d)' % obj['example_field'] + + serializer = ExampleSerializer({'example_field': 123}) + assert serializer.data == { + 'example_field': 'ran get_example_field(123)' + } + + def test_redundant_method_name(self): + class ExampleSerializer(serializers.Serializer): + example_field = serializers.SerializerMethodField('get_example_field') + + with pytest.raises(AssertionError) as exc_info: + ExampleSerializer().fields + assert str(exc_info.value) == ( + "It is redundant to specify `get_example_field` on " + "SerializerMethodField 'example_field' in serializer " + "'ExampleSerializer', because it is the same as the default " + "method name. Remove the `method_name` argument." + ) diff --git a/tests/test_files.py b/tests/test_files.py deleted file mode 100644 index de4f71d1..00000000 --- a/tests/test_files.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from django.utils import six -from rest_framework import serializers -from rest_framework.compat import BytesIO -import datetime - - -class UploadedFile(object): - def __init__(self, file=None, created=None): - self.file = file - self.created = created or datetime.datetime.now() - - -class UploadedFileSerializer(serializers.Serializer): - file = serializers.FileField(required=False) - created = serializers.DateTimeField() - - def restore_object(self, attrs, instance=None): - if instance: - instance.file = attrs['file'] - instance.created = attrs['created'] - return instance - return UploadedFile(**attrs) - - -class FileSerializerTests(TestCase): - def test_create(self): - now = datetime.datetime.now() - file = BytesIO(six.b('stuff')) - file.name = 'stuff.txt' - file.size = len(file.getvalue()) - serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) - uploaded_file = UploadedFile(file=file, created=now) - self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.object.created, uploaded_file.created) - self.assertEqual(serializer.object.file, uploaded_file.file) - self.assertFalse(serializer.object is uploaded_file) - - def test_creation_failure(self): - """ - Passing files=None should result in an ValidationError - - Regression test for: - https://github.com/tomchristie/django-rest-framework/issues/542 - """ - now = datetime.datetime.now() - - serializer = UploadedFileSerializer(data={'created': now}) - self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.object.created, now) - self.assertIsNone(serializer.object.file) - - def test_remove_with_empty_string(self): - """ - Passing empty string as data should cause file to be removed - - Test for: - https://github.com/tomchristie/django-rest-framework/issues/937 - """ - now = datetime.datetime.now() - file = BytesIO(six.b('stuff')) - file.name = 'stuff.txt' - file.size = len(file.getvalue()) - - uploaded_file = UploadedFile(file=file, created=now) - - serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) - self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.object.created, uploaded_file.created) - self.assertIsNone(serializer.object.file) - - def test_validation_error_with_non_file(self): - """ - Passing non-files should raise a validation error. - """ - now = datetime.datetime.now() - errmsg = 'No file was submitted. Check the encoding type on the form.' - - serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'file': [errmsg]}) - - def test_validation_with_no_data(self): - """ - Validation should still function when no data dictionary is provided. - """ - uploaded_file = BytesIO(six.b('stuff')) - uploaded_file.name = 'stuff.txt' - uploaded_file.size = len(uploaded_file.getvalue()) - serializer = UploadedFileSerializer(files={'file': uploaded_file}) - self.assertFalse(serializer.is_valid()) diff --git a/tests/test_filters.py b/tests/test_filters.py index 5722fd7c..dc84dcbd 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -2,10 +2,11 @@ from __future__ import unicode_literals import datetime from decimal import Decimal from django.db import models +from django.conf.urls import patterns, url from django.core.urlresolvers import reverse from django.test import TestCase from django.utils import unittest -from django.conf.urls import patterns, url +from django.utils.dateparse import parse_date from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory @@ -16,9 +17,14 @@ factory = APIRequestFactory() if django_filters: + class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_fields = ['decimal', 'date'] filter_backends = (filters.DjangoFilterBackend,) @@ -33,7 +39,8 @@ if django_filters: fields = ['text', 'decimal', 'date'] class FilterClassRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_class = SeveralFieldsFilter filter_backends = (filters.DjangoFilterBackend,) @@ -46,12 +53,14 @@ if django_filters: fields = ['text'] class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_class = MisconfiguredFilter filter_backends = (filters.DjangoFilterBackend,) class FilterClassDetailView(generics.RetrieveAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_class = SeveralFieldsFilter filter_backends = (filters.DjangoFilterBackend,) @@ -63,15 +72,12 @@ if django_filters: model = BaseFilterableItem class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer filter_class = BaseFilterableItemFilter filter_backends = (filters.DjangoFilterBackend,) # Regression test for #814 - class FilterableItemSerializer(serializers.ModelSerializer): - class Meta: - model = FilterableItem - class FilterFieldsQuerysetView(generics.ListCreateAPIView): queryset = FilterableItem.objects.all() serializer_class = FilterableItemSerializer @@ -97,7 +103,7 @@ if django_filters: class CommonFilteringTestCase(TestCase): def _serialize_object(self, obj): - return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} def setUp(self): """ @@ -140,7 +146,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] == search_decimal] + expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal] self.assertEqual(response.data, expected_data) # Tests that the date filter works. @@ -148,7 +154,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22' response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] == search_date] + expected_data = [f for f in self.data if parse_date(f['date']) == search_date] self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filter not installed') @@ -163,7 +169,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] == search_decimal] + expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal] self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filter not installed') @@ -196,7 +202,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/', {'decimal': '%s' % search_decimal}) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['decimal'] < search_decimal] + expected_data = [f for f in self.data if Decimal(f['decimal']) < search_decimal] self.assertEqual(response.data, expected_data) # Tests that the date filter set with 'gt' in the filter class works. @@ -204,7 +210,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02' response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date] + expected_data = [f for f in self.data if parse_date(f['date']) > search_date] self.assertEqual(response.data, expected_data) # Tests that the text filter set with 'icontains' in the filter class works. @@ -224,8 +230,8 @@ class IntegrationTestFiltering(CommonFilteringTestCase): }) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if f['date'] > search_date and - f['decimal'] < search_decimal] + expected_data = [f for f in self.data if parse_date(f['date']) > search_date and + Decimal(f['decimal']) < search_decimal] self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filter not installed') @@ -323,6 +329,11 @@ class SearchFilterModel(models.Model): text = models.CharField(max_length=100) +class SearchFilterSerializer(serializers.ModelSerializer): + class Meta: + model = SearchFilterModel + + class SearchFilterTests(TestCase): def setUp(self): # Sequence of title/text is: @@ -342,7 +353,8 @@ class SearchFilterTests(TestCase): def test_search(self): class SearchListView(generics.ListAPIView): - model = SearchFilterModel + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('title', 'text') @@ -359,7 +371,8 @@ class SearchFilterTests(TestCase): def test_exact_search(self): class SearchListView(generics.ListAPIView): - model = SearchFilterModel + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('=title', 'text') @@ -375,7 +388,8 @@ class SearchFilterTests(TestCase): def test_startswith_search(self): class SearchListView(generics.ListAPIView): - model = SearchFilterModel + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('title', '^text') @@ -392,7 +406,8 @@ class SearchFilterTests(TestCase): def test_search_with_nonstandard_search_param(self): with temporary_setting('SEARCH_PARAM', 'query', module=filters): class SearchListView(generics.ListAPIView): - model = SearchFilterModel + queryset = SearchFilterModel.objects.all() + serializer_class = SearchFilterSerializer filter_backends = (filters.SearchFilter,) search_fields = ('title', 'text') @@ -418,6 +433,11 @@ class OrderingFilterRelatedModel(models.Model): related_name="relateds") +class OrderingFilterSerializer(serializers.ModelSerializer): + class Meta: + model = OrderingFilterModel + + class DjangoFilterOrderingModel(models.Model): date = models.DateField() text = models.CharField(max_length=10) @@ -426,6 +446,11 @@ class DjangoFilterOrderingModel(models.Model): ordering = ['-date'] +class DjangoFilterOrderingSerializer(serializers.ModelSerializer): + class Meta: + model = DjangoFilterOrderingModel + + class DjangoFilterOrderingTests(TestCase): def setUp(self): data = [{ @@ -444,7 +469,8 @@ class DjangoFilterOrderingTests(TestCase): def test_default_ordering(self): class DjangoFilterOrderingView(generics.ListAPIView): - model = DjangoFilterOrderingModel + serializer_class = DjangoFilterOrderingSerializer + queryset = DjangoFilterOrderingModel.objects.all() filter_backends = (filters.DjangoFilterBackend,) filter_fields = ['text'] ordering = ('-date',) @@ -456,9 +482,9 @@ class DjangoFilterOrderingTests(TestCase): self.assertEqual( response.data, [ - {'id': 3, 'date': datetime.date(2014, 10, 8), 'text': 'cde'}, - {'id': 2, 'date': datetime.date(2013, 10, 8), 'text': 'bcd'}, - {'id': 1, 'date': datetime.date(2012, 10, 8), 'text': 'abc'} + {'id': 3, 'date': '2014-10-08', 'text': 'cde'}, + {'id': 2, 'date': '2013-10-08', 'text': 'bcd'}, + {'id': 1, 'date': '2012-10-08', 'text': 'abc'} ] ) @@ -485,7 +511,8 @@ class OrderingFilterTests(TestCase): def test_ordering(self): class OrderingListView(generics.ListAPIView): - model = OrderingFilterModel + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) ordering_fields = ('text',) @@ -504,7 +531,8 @@ class OrderingFilterTests(TestCase): def test_reverse_ordering(self): class OrderingListView(generics.ListAPIView): - model = OrderingFilterModel + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) ordering_fields = ('text',) @@ -523,7 +551,8 @@ class OrderingFilterTests(TestCase): def test_incorrectfield_ordering(self): class OrderingListView(generics.ListAPIView): - model = OrderingFilterModel + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) ordering_fields = ('text',) @@ -542,7 +571,8 @@ class OrderingFilterTests(TestCase): def test_default_ordering(self): class OrderingListView(generics.ListAPIView): - model = OrderingFilterModel + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) oredering_fields = ('text',) @@ -561,7 +591,8 @@ class OrderingFilterTests(TestCase): def test_default_ordering_using_string(self): class OrderingListView(generics.ListAPIView): - model = OrderingFilterModel + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = 'title' ordering_fields = ('text',) @@ -590,7 +621,7 @@ class OrderingFilterTests(TestCase): new_related.save() class OrderingListView(generics.ListAPIView): - model = OrderingFilterModel + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = 'title' ordering_fields = '__all__' @@ -612,7 +643,8 @@ class OrderingFilterTests(TestCase): def test_ordering_with_nonstandard_ordering_param(self): with temporary_setting('ORDERING_PARAM', 'order', filters): class OrderingListView(generics.ListAPIView): - model = OrderingFilterModel + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer filter_backends = (filters.OrderingFilter,) ordering = ('title',) ordering_fields = ('text',) diff --git a/tests/test_generics.py b/tests/test_generics.py index 97116349..94023c30 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,58 +1,75 @@ from __future__ import unicode_literals +import django from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase from django.utils import six from rest_framework import generics, renderers, serializers, status from rest_framework.test import APIRequestFactory -from tests.models import BasicModel, Comment, SlugBasedModel +from tests.models import BasicModel, RESTFrameworkModel from tests.models import ForeignKeySource, ForeignKeyTarget factory = APIRequestFactory() -class RootView(generics.ListCreateAPIView): - """ - Example description for OPTIONS. - """ - model = BasicModel +# Models +class SlugBasedModel(RESTFrameworkModel): + text = models.CharField(max_length=100) + slug = models.SlugField(max_length=32) -class InstanceView(generics.RetrieveUpdateDestroyAPIView): - """ - Example description for OPTIONS. - """ - model = BasicModel +# Model for regression test for #285 +class Comment(RESTFrameworkModel): + email = models.EmailField() + content = models.CharField(max_length=200) + created = models.DateTimeField(auto_now_add=True) - def get_queryset(self): - queryset = super(InstanceView, self).get_queryset() - return queryset.exclude(text='filtered out') +# Serializers +class BasicSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel -class FKInstanceView(generics.RetrieveUpdateDestroyAPIView): - """ - FK: example description for OPTIONS. - """ - model = ForeignKeySource + +class ForeignKeySerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeySource class SlugSerializer(serializers.ModelSerializer): - slug = serializers.Field() # read only + slug = serializers.ReadOnlyField() class Meta: model = SlugBasedModel - exclude = ('id',) + fields = ('text', 'slug') + + +# Views +class RootView(generics.ListCreateAPIView): + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer + + +class InstanceView(generics.RetrieveUpdateDestroyAPIView): + queryset = BasicModel.objects.exclude(text='filtered out') + serializer_class = BasicSerializer + + +class FKInstanceView(generics.RetrieveUpdateDestroyAPIView): + queryset = ForeignKeySource.objects.all() + serializer_class = ForeignKeySerializer class SlugBasedInstanceView(InstanceView): """ A model with a slug-field. """ - model = SlugBasedModel + queryset = SlugBasedModel.objects.all() serializer_class = SlugSerializer lookup_field = 'slug' +# Tests class TestRootView(TestCase): def setUp(self): """ @@ -112,47 +129,6 @@ class TestRootView(TestCase): self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) - def test_options_root_view(self): - """ - OPTIONS requests to ListCreateAPIView should return metadata - """ - request = factory.options('/') - with self.assertNumQueries(0): - response = self.view(request).render() - expected = { - 'parses': [ - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ], - 'renders': [ - 'application/json', - 'text/html' - ], - 'name': 'Root', - 'description': 'Example description for OPTIONS.', - 'actions': { - 'POST': { - 'text': { - 'max_length': 100, - 'read_only': False, - 'required': True, - 'type': 'string', - "label": "Text comes here", - "help_text": "Text description." - }, - 'id': { - 'read_only': True, - 'required': False, - 'type': 'integer', - 'label': 'ID', - }, - } - } - } - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, expected) - def test_post_cannot_set_id(self): """ POST requests to create a new object should not be able to set the id. @@ -167,10 +143,13 @@ class TestRootView(TestCase): self.assertEqual(created.text, 'foobar') +EXPECTED_QUERIES_FOR_PUT = 3 if django.VERSION < (1, 6) else 2 + + class TestInstanceView(TestCase): def setUp(self): """ - Create 3 BasicModel intances. + Create 3 BasicModel instances. """ items = ['foo', 'bar', 'baz', 'filtered out'] for item in items: @@ -210,10 +189,10 @@ class TestInstanceView(TestCase): """ data = {'text': 'foobar'} request = factory.put('/1', data, format='json') - with self.assertNumQueries(2): + with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): response = self.view(request, pk='1').render() self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) + self.assertEqual(dict(response.data), {'id': 1, 'text': 'foobar'}) updated = self.objects.get(id=1) self.assertEqual(updated.text, 'foobar') @@ -224,7 +203,7 @@ class TestInstanceView(TestCase): data = {'text': 'foobar'} request = factory.patch('/1', data, format='json') - with self.assertNumQueries(2): + with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) @@ -243,89 +222,6 @@ class TestInstanceView(TestCase): ids = [obj.id for obj in self.objects.all()] self.assertEqual(ids, [2, 3]) - def test_options_instance_view(self): - """ - OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata - """ - request = factory.options('/1') - with self.assertNumQueries(1): - response = self.view(request, pk=1).render() - expected = { - 'parses': [ - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ], - 'renders': [ - 'application/json', - 'text/html' - ], - 'name': 'Instance', - 'description': 'Example description for OPTIONS.', - 'actions': { - 'PUT': { - 'text': { - 'max_length': 100, - 'read_only': False, - 'required': True, - 'type': 'string', - 'label': 'Text comes here', - 'help_text': 'Text description.' - }, - 'id': { - 'read_only': True, - 'required': False, - 'type': 'integer', - 'label': 'ID', - }, - } - } - } - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, expected) - - def test_options_before_instance_create(self): - """ - OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata - before the instance has been created - """ - request = factory.options('/999') - with self.assertNumQueries(1): - response = self.view(request, pk=999).render() - expected = { - 'parses': [ - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ], - 'renders': [ - 'application/json', - 'text/html' - ], - 'name': 'Instance', - 'description': 'Example description for OPTIONS.', - 'actions': { - 'PUT': { - 'text': { - 'max_length': 100, - 'read_only': False, - 'required': True, - 'type': 'string', - 'label': 'Text comes here', - 'help_text': 'Text description.' - }, - 'id': { - 'read_only': True, - 'required': False, - 'type': 'integer', - 'label': 'ID', - }, - } - } - } - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, expected) - def test_get_instance_view_incorrect_arg(self): """ GET requests with an incorrect pk type, should raise 404, not 500. @@ -342,7 +238,7 @@ class TestInstanceView(TestCase): """ data = {'id': 999, 'text': 'foobar'} request = factory.put('/1', data, format='json') - with self.assertNumQueries(2): + with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) @@ -351,18 +247,15 @@ class TestInstanceView(TestCase): def test_put_to_deleted_instance(self): """ - PUT requests to RetrieveUpdateDestroyAPIView should create an object - if it does not currently exist. + PUT requests to RetrieveUpdateDestroyAPIView should return 404 if + an object does not currently exist. """ self.objects.get(id=1).delete() data = {'text': 'foobar'} request = factory.put('/1', data, format='json') - with self.assertNumQueries(3): + with self.assertNumQueries(1): response = self.view(request, pk=1).render() - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(response.data, {'id': 1, 'text': 'foobar'}) - updated = self.objects.get(id=1) - self.assertEqual(updated.text, 'foobar') + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_put_to_filtered_out_instance(self): """ @@ -373,35 +266,7 @@ class TestInstanceView(TestCase): filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') response = self.view(request, pk=filtered_out_pk).render() - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - - def test_put_as_create_on_id_based_url(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView should create an object - at the requested url if it doesn't exist. - """ - data = {'text': 'foobar'} - # pk fields can not be created on demand, only the database can set the pk for a new object - request = factory.put('/5', data, format='json') - with self.assertNumQueries(3): - response = self.view(request, pk=5).render() - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - new_obj = self.objects.get(pk=5) - self.assertEqual(new_obj.text, 'foobar') - - def test_put_as_create_on_slug_based_url(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView should create an object - at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. - """ - data = {'text': 'foobar'} - request = factory.put('/test_slug', data, format='json') - with self.assertNumQueries(2): - response = self.slug_based_view(request, slug='test_slug').render() - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'}) - new_obj = SlugBasedModel.objects.get(slug='test_slug') - self.assertEqual(new_obj.text, 'foobar') + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) def test_patch_cannot_create_an_object(self): """ @@ -433,62 +298,16 @@ class TestFKInstanceView(TestCase): ] self.view = FKInstanceView.as_view() - def test_options_root_view(self): - """ - OPTIONS requests to ListCreateAPIView should return metadata - """ - request = factory.options('/999') - with self.assertNumQueries(1): - response = self.view(request, pk=999).render() - expected = { - 'name': 'Fk Instance', - 'description': 'FK: example description for OPTIONS.', - 'renders': [ - 'application/json', - 'text/html' - ], - 'parses': [ - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ], - 'actions': { - 'PUT': { - 'id': { - 'type': 'integer', - 'required': False, - 'read_only': True, - 'label': 'ID' - }, - 'name': { - 'type': 'string', - 'required': True, - 'read_only': False, - 'label': 'name', - 'max_length': 100 - }, - 'target': { - 'type': 'field', - 'required': True, - 'read_only': False, - 'label': 'Target', - 'help_text': 'Target' - } - } - } - } - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, expected) - class TestOverriddenGetObject(TestCase): """ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the queryset/model mechanism but instead overrides get_object() """ + def setUp(self): """ - Create 3 BasicModel intances. + Create 3 BasicModel instances. """ items = ['foo', 'bar', 'baz'] for item in items: @@ -503,7 +322,7 @@ class TestOverriddenGetObject(TestCase): """ Example detail view for override of get_object(). """ - model = BasicModel + serializer_class = BasicSerializer def get_object(self): pk = int(self.kwargs['pk']) @@ -561,11 +380,13 @@ class ClassB(models.Model): class ClassA(models.Model): name = models.CharField(max_length=255) - childs = models.ManyToManyField(ClassB, blank=True, null=True) + children = models.ManyToManyField(ClassB, blank=True, null=True) class ClassASerializer(serializers.ModelSerializer): - childs = serializers.PrimaryKeyRelatedField(many=True, source='childs') + children = serializers.PrimaryKeyRelatedField( + many=True, queryset=ClassB.objects.all() + ) class Meta: model = ClassA @@ -573,11 +394,11 @@ class ClassASerializer(serializers.ModelSerializer): class ExampleView(generics.ListCreateAPIView): serializer_class = ClassASerializer - model = ClassA + queryset = ClassA.objects.all() -class TestM2MBrowseableAPI(TestCase): - def test_m2m_in_browseable_api(self): +class TestM2MBrowsableAPI(TestCase): + def test_m2m_in_browsable_api(self): """ Test for particularly ugly regression with m2m in browsable API """ @@ -603,7 +424,7 @@ class TwoFieldModel(models.Model): class DynamicSerializerView(generics.ListCreateAPIView): - model = TwoFieldModel + queryset = TwoFieldModel.objects.all() renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) def get_serializer_class(self): @@ -612,12 +433,14 @@ class DynamicSerializerView(generics.ListCreateAPIView): class Meta: model = TwoFieldModel fields = ('field_b',) - return DynamicSerializer - return super(DynamicSerializerView, self).get_serializer_class() + else: + class DynamicSerializer(serializers.ModelSerializer): + class Meta: + model = TwoFieldModel + return DynamicSerializer class TestFilterBackendAppliedToViews(TestCase): - def setUp(self): """ Create 3 BasicModel instances to filter on. @@ -681,42 +504,3 @@ class TestFilterBackendAppliedToViews(TestCase): response = view(request).render() self.assertContains(response, 'field_b') self.assertNotContains(response, 'field_a') - - def test_options_with_dynamic_serializer(self): - """ - Ensure that OPTIONS returns correct POST json schema: - DynamicSerializer with single field 'field_b' - """ - request = factory.options('/') - view = DynamicSerializerView.as_view() - - with self.assertNumQueries(0): - response = view(request).render() - - expected = { - 'name': 'Dynamic Serializer', - 'description': '', - 'renders': [ - 'text/html', - 'application/json' - ], - 'parses': [ - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ], - 'actions': { - 'POST': { - 'field_b': { - 'type': 'string', - 'required': True, - 'read_only': False, - 'label': 'field b', - 'max_length': 100 - } - } - } - } - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, expected) diff --git a/tests/test_hyperlinkedserializers.py b/tests/test_hyperlinkedserializers.py deleted file mode 100644 index d4548539..00000000 --- a/tests/test_hyperlinkedserializers.py +++ /dev/null @@ -1,380 +0,0 @@ -from __future__ import unicode_literals -import json -from django.test import TestCase -from rest_framework import generics, status, serializers -from django.conf.urls import patterns, url -from rest_framework.settings import api_settings -from rest_framework.test import APIRequestFactory -from tests.models import ( - Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, - Album, Photo, OptionalRelationModel -) - -factory = APIRequestFactory() - - -class BlogPostCommentSerializer(serializers.ModelSerializer): - url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') - text = serializers.CharField() - blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') - - class Meta: - model = BlogPostComment - fields = ('text', 'blog_post_url', 'url') - - -class PhotoSerializer(serializers.Serializer): - description = serializers.CharField() - album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title') - - def restore_object(self, attrs, instance=None): - return Photo(**attrs) - - -class AlbumSerializer(serializers.ModelSerializer): - url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') - - class Meta: - model = Album - fields = ('title', 'url') - - -class BasicList(generics.ListCreateAPIView): - model = BasicModel - model_serializer_class = serializers.HyperlinkedModelSerializer - - -class BasicDetail(generics.RetrieveUpdateDestroyAPIView): - model = BasicModel - model_serializer_class = serializers.HyperlinkedModelSerializer - - -class AnchorDetail(generics.RetrieveAPIView): - model = Anchor - model_serializer_class = serializers.HyperlinkedModelSerializer - - -class ManyToManyList(generics.ListAPIView): - model = ManyToManyModel - model_serializer_class = serializers.HyperlinkedModelSerializer - - -class ManyToManyDetail(generics.RetrieveAPIView): - model = ManyToManyModel - model_serializer_class = serializers.HyperlinkedModelSerializer - - -class BlogPostCommentListCreate(generics.ListCreateAPIView): - model = BlogPostComment - serializer_class = BlogPostCommentSerializer - - -class BlogPostCommentDetail(generics.RetrieveAPIView): - model = BlogPostComment - serializer_class = BlogPostCommentSerializer - - -class BlogPostDetail(generics.RetrieveAPIView): - model = BlogPost - - -class PhotoListCreate(generics.ListCreateAPIView): - model = Photo - model_serializer_class = PhotoSerializer - - -class AlbumDetail(generics.RetrieveAPIView): - model = Album - serializer_class = AlbumSerializer - lookup_field = 'title' - - -class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): - model = OptionalRelationModel - model_serializer_class = serializers.HyperlinkedModelSerializer - - -urlpatterns = patterns( - '', - url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), - url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), - url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), - url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), - url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), - url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), - url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), - url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), - url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), - url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), - url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), -) - - -class TestBasicHyperlinkedView(TestCase): - urls = 'tests.test_hyperlinkedserializers' - - def setUp(self): - """ - Create 3 BasicModel instances. - """ - items = ['foo', 'bar', 'baz'] - for item in items: - BasicModel(text=item).save() - self.objects = BasicModel.objects - self.data = [ - {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.list_view = BasicList.as_view() - self.detail_view = BasicDetail.as_view() - - def test_get_list_view(self): - """ - GET requests to ListCreateAPIView should return list of objects. - """ - request = factory.get('/basic/') - response = self.list_view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data) - - def test_get_detail_view(self): - """ - GET requests to ListCreateAPIView should return list of objects. - """ - request = factory.get('/basic/1') - response = self.detail_view(request, pk=1).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data[0]) - - -class TestManyToManyHyperlinkedView(TestCase): - urls = 'tests.test_hyperlinkedserializers' - - def setUp(self): - """ - Create 3 BasicModel instances. - """ - items = ['foo', 'bar', 'baz'] - anchors = [] - for item in items: - anchor = Anchor(text=item) - anchor.save() - anchors.append(anchor) - - manytomany = ManyToManyModel() - manytomany.save() - manytomany.rel.add(*anchors) - - self.data = [{ - 'url': 'http://testserver/manytomany/1/', - 'rel': [ - 'http://testserver/anchor/1/', - 'http://testserver/anchor/2/', - 'http://testserver/anchor/3/', - ] - }] - self.list_view = ManyToManyList.as_view() - self.detail_view = ManyToManyDetail.as_view() - - def test_get_list_view(self): - """ - GET requests to ListCreateAPIView should return list of objects. - """ - request = factory.get('/manytomany/') - response = self.list_view(request) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data) - - def test_get_detail_view(self): - """ - GET requests to ListCreateAPIView should return list of objects. - """ - request = factory.get('/manytomany/1/') - response = self.detail_view(request, pk=1) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data[0]) - - -class TestHyperlinkedIdentityFieldLookup(TestCase): - urls = 'tests.test_hyperlinkedserializers' - - def setUp(self): - """ - Create 3 Album instances. - """ - titles = ['foo', 'bar', 'baz'] - for title in titles: - album = Album(title=title) - album.save() - self.detail_view = AlbumDetail.as_view() - self.data = { - 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, - 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, - 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} - } - - def test_lookup_field(self): - """ - GET requests to AlbumDetail view should return serialized Albums - with a url field keyed by `title`. - """ - for album in Album.objects.all(): - request = factory.get('/albums/{0}/'.format(album.title)) - response = self.detail_view(request, title=album.title) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data[album.title]) - - -class TestCreateWithForeignKeys(TestCase): - urls = 'tests.test_hyperlinkedserializers' - - def setUp(self): - """ - Create a blog post - """ - self.post = BlogPost.objects.create(title="Test post") - self.create_view = BlogPostCommentListCreate.as_view() - - def test_create_comment(self): - - data = { - 'text': 'A test comment', - 'blog_post_url': 'http://testserver/posts/1/' - } - - request = factory.post('/comments/', data=data) - response = self.create_view(request) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(response['Location'], 'http://testserver/comments/1/') - self.assertEqual(self.post.blogpostcomment_set.count(), 1) - self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') - - -class TestCreateWithForeignKeysAndCustomSlug(TestCase): - urls = 'tests.test_hyperlinkedserializers' - - def setUp(self): - """ - Create an Album - """ - self.post = Album.objects.create(title='test-album') - self.list_create_view = PhotoListCreate.as_view() - - def test_create_photo(self): - - data = { - 'description': 'A test photo', - 'album_url': 'http://testserver/albums/test-album/' - } - - request = factory.post('/photos/', data=data) - response = self.list_create_view(request) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') - self.assertEqual(self.post.photo_set.count(), 1) - self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') - - -class TestOptionalRelationHyperlinkedView(TestCase): - urls = 'tests.test_hyperlinkedserializers' - - def setUp(self): - """ - Create 1 OptionalRelationModel instances. - """ - OptionalRelationModel().save() - self.objects = OptionalRelationModel.objects - self.detail_view = OptionalRelationDetail.as_view() - self.data = {"url": "http://testserver/optionalrelation/1/", "other": None} - - def test_get_detail_view(self): - """ - GET requests to RetrieveAPIView with optional relations should return None - for non existing relations. - """ - request = factory.get('/optionalrelationmodel-detail/1') - response = self.detail_view(request, pk=1) - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data, self.data) - - def test_put_detail_view(self): - """ - PUT requests to RetrieveUpdateDestroyAPIView with optional relations - should accept None for non existing relations. - """ - response = self.client.put('/optionalrelation/1/', - data=json.dumps(self.data), - content_type='application/json') - self.assertEqual(response.status_code, status.HTTP_200_OK) - - -class TestOverriddenURLField(TestCase): - def setUp(self): - class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer): - url = serializers.SerializerMethodField('get_url') - - class Meta: - model = BlogPost - fields = ('title', 'url') - - def get_url(self, obj): - return 'foo bar' - - self.Serializer = OverriddenURLSerializer - self.obj = BlogPost.objects.create(title='New blog post') - - def test_overridden_url_field(self): - """ - The 'url' field should respect overriding. - Regression test for #936. - """ - serializer = self.Serializer(self.obj) - self.assertEqual( - serializer.data, - {'title': 'New blog post', 'url': 'foo bar'} - ) - - -class TestURLFieldNameBySettings(TestCase): - urls = 'tests.test_hyperlinkedserializers' - - def setUp(self): - self.saved_url_field_name = api_settings.URL_FIELD_NAME - api_settings.URL_FIELD_NAME = 'global_url_field' - - class Serializer(serializers.HyperlinkedModelSerializer): - - class Meta: - model = BlogPost - fields = ('title', api_settings.URL_FIELD_NAME) - - self.Serializer = Serializer - self.obj = BlogPost.objects.create(title="New blog post") - - def tearDown(self): - api_settings.URL_FIELD_NAME = self.saved_url_field_name - - def test_overridden_url_field_name(self): - request = factory.get('/posts/') - serializer = self.Serializer(self.obj, context={'request': request}) - self.assertIn(api_settings.URL_FIELD_NAME, serializer.data) - - -class TestURLFieldNameByOptions(TestCase): - urls = 'tests.test_hyperlinkedserializers' - - def setUp(self): - class Serializer(serializers.HyperlinkedModelSerializer): - - class Meta: - model = BlogPost - fields = ('title', 'serializer_url_field') - url_field_name = 'serializer_url_field' - - self.Serializer = Serializer - self.obj = BlogPost.objects.create(title="New blog post") - - def test_overridden_url_field_name(self): - request = factory.get('/posts/') - serializer = self.Serializer(self.obj, context={'request': request}) - self.assertIn(self.Serializer.Meta.url_field_name, serializer.data) diff --git a/tests/test_metadata.py b/tests/test_metadata.py new file mode 100644 index 00000000..5ff59c72 --- /dev/null +++ b/tests/test_metadata.py @@ -0,0 +1,166 @@ +from __future__ import unicode_literals + +from rest_framework import exceptions, serializers, views +from rest_framework.request import Request +from rest_framework.test import APIRequestFactory +import pytest + +request = Request(APIRequestFactory().options('/')) + + +class TestMetadata: + def test_metadata(self): + """ + OPTIONS requests to views should return a valid 200 response. + """ + class ExampleView(views.APIView): + """Example view.""" + pass + + response = ExampleView().options(request=request) + expected = { + 'name': 'Example', + 'description': 'Example view.', + 'renders': [ + 'application/json', + 'text/html' + ], + 'parses': [ + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data' + ] + } + assert response.status_code == 200 + assert response.data == expected + + def test_none_metadata(self): + """ + OPTIONS requests to views where `metadata_class = None` should raise + a MethodNotAllowed exception, which will result in an HTTP 405 response. + """ + class ExampleView(views.APIView): + metadata_class = None + + with pytest.raises(exceptions.MethodNotAllowed): + ExampleView().options(request=request) + + def test_actions(self): + """ + On generic views OPTIONS should return an 'actions' key with metadata + on the fields that may be supplied to PUT and POST requests. + """ + class ExampleSerializer(serializers.Serializer): + choice_field = serializers.ChoiceField(['red', 'green', 'blue']) + integer_field = serializers.IntegerField(max_value=10) + char_field = serializers.CharField(required=False) + + class ExampleView(views.APIView): + """Example view.""" + def post(self, request): + pass + + def get_serializer(self): + return ExampleSerializer() + + response = ExampleView().options(request=request) + expected = { + 'name': 'Example', + 'description': 'Example view.', + 'renders': [ + 'application/json', + 'text/html' + ], + 'parses': [ + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data' + ], + 'actions': { + 'POST': { + 'choice_field': { + 'type': 'choice', + 'required': True, + 'read_only': False, + 'label': 'Choice field', + 'choices': [ + {'display_name': 'red', 'value': 'red'}, + {'display_name': 'green', 'value': 'green'}, + {'display_name': 'blue', 'value': 'blue'} + ] + }, + 'integer_field': { + 'type': 'integer', + 'required': True, + 'read_only': False, + 'label': 'Integer field' + }, + 'char_field': { + 'type': 'string', + 'required': False, + 'read_only': False, + 'label': 'Char field' + } + } + } + } + assert response.status_code == 200 + assert response.data == expected + + def test_global_permissions(self): + """ + If a user does not have global permissions on an action, then any + metadata associated with it should not be included in OPTION responses. + """ + class ExampleSerializer(serializers.Serializer): + choice_field = serializers.ChoiceField(['red', 'green', 'blue']) + integer_field = serializers.IntegerField(max_value=10) + char_field = serializers.CharField(required=False) + + class ExampleView(views.APIView): + """Example view.""" + def post(self, request): + pass + + def put(self, request): + pass + + def get_serializer(self): + return ExampleSerializer() + + def check_permissions(self, request): + if request.method == 'POST': + raise exceptions.PermissionDenied() + + response = ExampleView().options(request=request) + assert response.status_code == 200 + assert list(response.data['actions'].keys()) == ['PUT'] + + def test_object_permissions(self): + """ + If a user does not have object permissions on an action, then any + metadata associated with it should not be included in OPTION responses. + """ + class ExampleSerializer(serializers.Serializer): + choice_field = serializers.ChoiceField(['red', 'green', 'blue']) + integer_field = serializers.IntegerField(max_value=10) + char_field = serializers.CharField(required=False) + + class ExampleView(views.APIView): + """Example view.""" + def post(self, request): + pass + + def put(self, request): + pass + + def get_serializer(self): + return ExampleSerializer() + + def get_object(self): + if self.request.method == 'PUT': + raise exceptions.PermissionDenied() + + response = ExampleView().options(request=request) + assert response.status_code == 200 + assert list(response.data['actions'].keys()) == ['POST'] diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 00000000..4c099fca --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,37 @@ + +from django.conf.urls import patterns, url +from django.contrib.auth.models import User +from rest_framework.authentication import TokenAuthentication +from rest_framework.authtoken.models import Token +from rest_framework.test import APITestCase +from rest_framework.views import APIView + + +urlpatterns = patterns( + '', + url(r'^$', APIView.as_view(authentication_classes=(TokenAuthentication,))), +) + + +class MyMiddleware(object): + + def process_response(self, request, response): + assert hasattr(request, 'user'), '`user` is not set on request' + assert request.user.is_authenticated(), '`user` is not authenticated' + return response + + +class TestMiddleware(APITestCase): + + urls = 'tests.test_middleware' + + def test_middleware_can_access_user_when_processing_response(self): + user = User.objects.create_user('john', 'john@example.com', 'password') + key = 'abcd1234' + Token.objects.create(key=key, user=user) + + with self.settings( + MIDDLEWARE_CLASSES=('tests.test_middleware.MyMiddleware',) + ): + auth = 'Token ' + key + self.client.get('/', HTTP_AUTHORIZATION=auth) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py new file mode 100644 index 00000000..da79164a --- /dev/null +++ b/tests/test_model_serializer.py @@ -0,0 +1,611 @@ +""" +The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially +shortcuts for automatically creating serializers based on a given model class. + +These tests deal with ensuring that we correctly map the model fields onto +an appropriate set of serializer fields for each case. +""" +from django.core.exceptions import ImproperlyConfigured +from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +def dedent(blocktext): + return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]]) + + +# Tests for regular field mappings. +# --------------------------------- + +class CustomField(models.Field): + """ + A custom model field simply for testing purposes. + """ + pass + + +class OneFieldModel(models.Model): + char_field = models.CharField(max_length=100) + + +class RegularFieldsModel(models.Model): + """ + A model class for testing regular flat fields. + """ + auto_field = models.AutoField(primary_key=True) + big_integer_field = models.BigIntegerField() + boolean_field = models.BooleanField(default=False) + char_field = models.CharField(max_length=100) + comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=100) + date_field = models.DateField() + datetime_field = models.DateTimeField() + decimal_field = models.DecimalField(max_digits=3, decimal_places=1) + email_field = models.EmailField(max_length=100) + float_field = models.FloatField() + integer_field = models.IntegerField() + null_boolean_field = models.NullBooleanField() + positive_integer_field = models.PositiveIntegerField() + positive_small_integer_field = models.PositiveSmallIntegerField() + slug_field = models.SlugField(max_length=100) + small_integer_field = models.SmallIntegerField() + text_field = models.TextField() + time_field = models.TimeField() + url_field = models.URLField(max_length=100) + custom_field = CustomField() + + def method(self): + return 'method' + + +COLOR_CHOICES = (('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')) + + +class FieldOptionsModel(models.Model): + value_limit_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(10)]) + length_limit_field = models.CharField(validators=[MinLengthValidator(3)], max_length=12) + blank_field = models.CharField(blank=True, max_length=10) + null_field = models.IntegerField(null=True) + default_field = models.IntegerField(default=0) + descriptive_field = models.IntegerField(help_text='Some help text', verbose_name='A label') + choices_field = models.CharField(max_length=100, choices=COLOR_CHOICES) + + +class TestModelSerializer(TestCase): + def test_create_method(self): + class TestSerializer(serializers.ModelSerializer): + non_model_field = serializers.CharField() + + class Meta: + model = OneFieldModel + fields = ('char_field', 'non_model_field') + + serializer = TestSerializer(data={ + 'char_field': 'foo', + 'non_model_field': 'bar', + }) + serializer.is_valid() + with self.assertRaises(TypeError) as excinfo: + serializer.save() + msginitial = 'Got a `TypeError` when calling `OneFieldModel.objects.create()`.' + assert str(excinfo.exception).startswith(msginitial) + + +class TestRegularFieldMappings(TestCase): + def test_regular_fields(self): + """ + Model fields should map to their equivelent serializer fields. + """ + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + + expected = dedent(""" + TestSerializer(): + auto_field = IntegerField(read_only=True) + big_integer_field = IntegerField() + boolean_field = BooleanField(required=False) + char_field = CharField(max_length=100) + comma_separated_integer_field = CharField(max_length=100, validators=[<django.core.validators.RegexValidator object>]) + date_field = DateField() + datetime_field = DateTimeField() + decimal_field = DecimalField(decimal_places=1, max_digits=3) + email_field = EmailField(max_length=100) + float_field = FloatField() + integer_field = IntegerField() + null_boolean_field = NullBooleanField(required=False) + positive_integer_field = IntegerField() + positive_small_integer_field = IntegerField() + slug_field = SlugField(max_length=100) + small_integer_field = IntegerField() + text_field = CharField(style={'type': 'textarea'}) + time_field = TimeField() + url_field = URLField(max_length=100) + custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_field_options(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = FieldOptionsModel + + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + value_limit_field = IntegerField(max_value=10, min_value=1) + length_limit_field = CharField(max_length=12, min_length=3) + blank_field = CharField(allow_blank=True, max_length=10, required=False) + null_field = IntegerField(allow_null=True, required=False) + default_field = IntegerField(required=False) + descriptive_field = IntegerField(help_text='Some help text', label='A label') + choices_field = ChoiceField(choices=[('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')]) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_method_field(self): + """ + Properties and methods on the model should be allowed as `Meta.fields` + values, and should map to `ReadOnlyField`. + """ + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('auto_field', 'method') + + expected = dedent(""" + TestSerializer(): + auto_field = IntegerField(read_only=True) + method = ReadOnlyField() + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_pk_fields(self): + """ + Both `pk` and the actual primary key name are valid in `Meta.fields`. + """ + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('pk', 'auto_field') + + expected = dedent(""" + TestSerializer(): + pk = IntegerField(label='Auto field', read_only=True) + auto_field = IntegerField(read_only=True) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_extra_field_kwargs(self): + """ + Ensure `extra_kwargs` are passed to generated fields. + """ + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('auto_field', 'char_field') + extra_kwargs = {'char_field': {'default': 'extra'}} + + expected = dedent(""" + TestSerializer(): + auto_field = IntegerField(read_only=True) + char_field = CharField(default='extra', max_length=100) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_invalid_field(self): + """ + Field names that do not map to a model field or relationship should + raise a configuration errror. + """ + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RegularFieldsModel + fields = ('auto_field', 'invalid') + + with self.assertRaises(ImproperlyConfigured) as excinfo: + TestSerializer().fields + expected = 'Field name `invalid` is not valid for model `ModelBase`.' + assert str(excinfo.exception) == expected + + def test_missing_field(self): + """ + Fields that have been declared on the serializer class must be included + in the `Meta.fields` if it exists. + """ + class TestSerializer(serializers.ModelSerializer): + missing = serializers.ReadOnlyField() + + class Meta: + model = RegularFieldsModel + fields = ('auto_field',) + + with self.assertRaises(ImproperlyConfigured) as excinfo: + TestSerializer().fields + expected = ( + 'Field `missing` has been declared on serializer ' + '`TestSerializer`, but is missing from `Meta.fields`.' + ) + assert str(excinfo.exception) == expected + + +# Tests for relational field mappings. +# ------------------------------------ + +class ForeignKeyTargetModel(models.Model): + name = models.CharField(max_length=100) + + +class ManyToManyTargetModel(models.Model): + name = models.CharField(max_length=100) + + +class OneToOneTargetModel(models.Model): + name = models.CharField(max_length=100) + + +class ThroughTargetModel(models.Model): + name = models.CharField(max_length=100) + + +class Supplementary(models.Model): + extra = models.IntegerField() + forwards = models.ForeignKey('ThroughTargetModel') + backwards = models.ForeignKey('RelationalModel') + + +class RelationalModel(models.Model): + foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='reverse_foreign_key') + many_to_many = models.ManyToManyField(ManyToManyTargetModel, related_name='reverse_many_to_many') + one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one') + through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through') + + +class TestRelationalFieldMappings(TestCase): + def test_pk_relations(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all()) + one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>]) + many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all()) + through = PrimaryKeyRelatedField(many=True, read_only=True) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_nested_relations(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + depth = 1 + + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + foreign_key = NestedSerializer(read_only=True): + id = IntegerField(label='ID', read_only=True) + name = CharField(max_length=100) + one_to_one = NestedSerializer(read_only=True): + id = IntegerField(label='ID', read_only=True) + name = CharField(max_length=100) + many_to_many = NestedSerializer(many=True, read_only=True): + id = IntegerField(label='ID', read_only=True) + name = CharField(max_length=100) + through = NestedSerializer(many=True, read_only=True): + id = IntegerField(label='ID', read_only=True) + name = CharField(max_length=100) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_hyperlinked_relations(self): + class TestSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RelationalModel + + expected = dedent(""" + TestSerializer(): + url = HyperlinkedIdentityField(view_name='relationalmodel-detail') + foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail') + one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>], view_name='onetoonetargetmodel-detail') + many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') + through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_nested_hyperlinked_relations(self): + class TestSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RelationalModel + depth = 1 + + expected = dedent(""" + TestSerializer(): + url = HyperlinkedIdentityField(view_name='relationalmodel-detail') + foreign_key = NestedSerializer(read_only=True): + url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail') + name = CharField(max_length=100) + one_to_one = NestedSerializer(read_only=True): + url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') + name = CharField(max_length=100) + many_to_many = NestedSerializer(many=True, read_only=True): + url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail') + name = CharField(max_length=100) + through = NestedSerializer(many=True, read_only=True): + url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') + name = CharField(max_length=100) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_pk_reverse_foreign_key(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = ForeignKeyTargetModel + fields = ('id', 'name', 'reverse_foreign_key') + + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + name = CharField(max_length=100) + reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_pk_reverse_one_to_one(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneTargetModel + fields = ('id', 'name', 'reverse_one_to_one') + + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + name = CharField(max_length=100) + reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all()) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_pk_reverse_many_to_many(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = ManyToManyTargetModel + fields = ('id', 'name', 'reverse_many_to_many') + + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + name = CharField(max_length=100) + reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) + """) + self.assertEqual(repr(TestSerializer()), expected) + + def test_pk_reverse_through(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = ThroughTargetModel + fields = ('id', 'name', 'reverse_through') + + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + name = CharField(max_length=100) + reverse_through = PrimaryKeyRelatedField(many=True, read_only=True) + """) + self.assertEqual(repr(TestSerializer()), expected) + + +class TestIntegration(TestCase): + def setUp(self): + self.foreign_key_target = ForeignKeyTargetModel.objects.create( + name='foreign_key' + ) + self.one_to_one_target = OneToOneTargetModel.objects.create( + name='one_to_one' + ) + self.many_to_many_targets = [ + ManyToManyTargetModel.objects.create( + name='many_to_many (%d)' % idx + ) for idx in range(3) + ] + self.instance = RelationalModel.objects.create( + foreign_key=self.foreign_key_target, + one_to_one=self.one_to_one_target, + ) + self.instance.many_to_many = self.many_to_many_targets + self.instance.save() + + def test_pk_retrival(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + + serializer = TestSerializer(self.instance) + expected = { + 'id': self.instance.pk, + 'foreign_key': self.foreign_key_target.pk, + 'one_to_one': self.one_to_one_target.pk, + 'many_to_many': [item.pk for item in self.many_to_many_targets], + 'through': [] + } + self.assertEqual(serializer.data, expected) + + def test_pk_create(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + + new_foreign_key = ForeignKeyTargetModel.objects.create( + name='foreign_key' + ) + new_one_to_one = OneToOneTargetModel.objects.create( + name='one_to_one' + ) + new_many_to_many = [ + ManyToManyTargetModel.objects.create( + name='new many_to_many (%d)' % idx + ) for idx in range(3) + ] + data = { + 'foreign_key': new_foreign_key.pk, + 'one_to_one': new_one_to_one.pk, + 'many_to_many': [item.pk for item in new_many_to_many], + } + + # Serializer should validate okay. + serializer = TestSerializer(data=data) + assert serializer.is_valid() + + # Creating the instance, relationship attributes should be set. + instance = serializer.save() + assert instance.foreign_key.pk == new_foreign_key.pk + assert instance.one_to_one.pk == new_one_to_one.pk + assert [ + item.pk for item in instance.many_to_many.all() + ] == [ + item.pk for item in new_many_to_many + ] + assert list(instance.through.all()) == [] + + # Representation should be correct. + expected = { + 'id': instance.pk, + 'foreign_key': new_foreign_key.pk, + 'one_to_one': new_one_to_one.pk, + 'many_to_many': [item.pk for item in new_many_to_many], + 'through': [] + } + self.assertEqual(serializer.data, expected) + + def test_pk_update(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + + new_foreign_key = ForeignKeyTargetModel.objects.create( + name='foreign_key' + ) + new_one_to_one = OneToOneTargetModel.objects.create( + name='one_to_one' + ) + new_many_to_many = [ + ManyToManyTargetModel.objects.create( + name='new many_to_many (%d)' % idx + ) for idx in range(3) + ] + data = { + 'foreign_key': new_foreign_key.pk, + 'one_to_one': new_one_to_one.pk, + 'many_to_many': [item.pk for item in new_many_to_many], + } + + # Serializer should validate okay. + serializer = TestSerializer(self.instance, data=data) + assert serializer.is_valid() + + # Creating the instance, relationship attributes should be set. + instance = serializer.save() + assert instance.foreign_key.pk == new_foreign_key.pk + assert instance.one_to_one.pk == new_one_to_one.pk + assert [ + item.pk for item in instance.many_to_many.all() + ] == [ + item.pk for item in new_many_to_many + ] + assert list(instance.through.all()) == [] + + # Representation should be correct. + expected = { + 'id': self.instance.pk, + 'foreign_key': new_foreign_key.pk, + 'one_to_one': new_one_to_one.pk, + 'many_to_many': [item.pk for item in new_many_to_many], + 'through': [] + } + self.assertEqual(serializer.data, expected) + + +# Tests for bulk create using `ListSerializer`. + +class BulkCreateModel(models.Model): + name = models.CharField(max_length=10) + + +class TestBulkCreate(TestCase): + def test_bulk_create(self): + class BasicModelSerializer(serializers.ModelSerializer): + class Meta: + model = BulkCreateModel + fields = ('name',) + + class BulkCreateSerializer(serializers.ListSerializer): + child = BasicModelSerializer() + + data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}] + serializer = BulkCreateSerializer(data=data) + assert serializer.is_valid() + + # Objects are returned by save(). + instances = serializer.save() + assert len(instances) == 3 + assert [item.name for item in instances] == ['a', 'b', 'c'] + + # Objects have been created in the database. + assert BulkCreateModel.objects.count() == 3 + assert list(BulkCreateModel.objects.values_list('name', flat=True)) == ['a', 'b', 'c'] + + # Serializer returns correct data. + assert serializer.data == data + + +class TestMetaClassModel(models.Model): + text = models.CharField(max_length=100) + + +class TestSerializerMetaClass(TestCase): + def test_meta_class_fields_option(self): + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = TestMetaClassModel + fields = 'text' + + with self.assertRaises(TypeError) as result: + ExampleSerializer().fields + + exception = result.exception + assert str(exception).startswith( + "The `fields` option must be a list or tuple" + ) + + def test_meta_class_exclude_option(self): + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = TestMetaClassModel + exclude = 'text' + + with self.assertRaises(TypeError) as result: + ExampleSerializer().fields + + exception = result.exception + assert str(exception).startswith( + "The `exclude` option must be a list or tuple" + ) + + def test_meta_class_fields_and_exclude_options(self): + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = TestMetaClassModel + fields = ('text',) + exclude = ('text',) + + with self.assertRaises(AssertionError) as result: + ExampleSerializer().fields + + exception = result.exception + self.assertEqual( + str(exception), + "Cannot set both 'fields' and 'exclude'." + ) diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py index ce1bf3ea..e1b40cc7 100644 --- a/tests/test_multitable_inheritance.py +++ b/tests/test_multitable_inheritance.py @@ -31,7 +31,7 @@ class AssociatedModelSerializer(serializers.ModelSerializer): # Tests -class IneritedModelSerializationTests(TestCase): +class InheritedModelSerializationTests(TestCase): def test_multitable_inherited_model_fields_as_expected(self): """ diff --git a/tests/test_nullable_fields.py b/tests/test_nullable_fields.py deleted file mode 100644 index 0c133fc2..00000000 --- a/tests/test_nullable_fields.py +++ /dev/null @@ -1,30 +0,0 @@ -from django.core.urlresolvers import reverse - -from django.conf.urls import patterns, url -from rest_framework.test import APITestCase -from tests.models import NullableForeignKeySource -from tests.serializers import NullableFKSourceSerializer -from tests.views import NullableFKSourceDetail - - -urlpatterns = patterns( - '', - url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), -) - - -class NullableForeignKeyTests(APITestCase): - """ - DRF should be able to handle nullable foreign keys when a test - Client POST/PUT request is made with its own serialized object. - """ - urls = 'tests.test_nullable_fields' - - def test_updating_object_with_null_fk(self): - obj = NullableForeignKeySource(name='example', target=None) - obj.save() - serialized_data = NullableFKSourceSerializer(obj).data - - response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) - - self.assertEqual(response.data, serialized_data) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index e1c2528b..1fd9cf9c 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -4,7 +4,7 @@ from decimal import Decimal from django.core.paginator import Paginator from django.test import TestCase from django.utils import unittest -from rest_framework import generics, status, pagination, filters, serializers +from rest_framework import generics, serializers, status, pagination, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from .models import BasicModel, FilterableItem @@ -22,11 +22,22 @@ def split_arguments_from_url(url): return path, args +class BasicSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + +class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + + class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer paginate_by = 10 @@ -34,14 +45,16 @@ class DefaultPageSizeKwargView(generics.ListAPIView): """ View for testing default paginate_by_param usage """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer class PaginateByParamView(generics.ListAPIView): """ View for testing custom paginate_by_param usage """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer paginate_by_param = 'page_size' @@ -49,7 +62,8 @@ class MaxPaginateByView(generics.ListAPIView): """ View for testing custom max_paginate_by usage """ - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer paginate_by = 3 max_paginate_by = 5 paginate_by_param = 'page_size' @@ -121,7 +135,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} for obj in self.objects.all() ] @@ -140,7 +154,8 @@ class IntegrationTestPaginationAndFiltering(TestCase): fields = ['text', 'decimal', 'date'] class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer paginate_by = 10 filter_class = DecimalFilter filter_backends = (filters.DjangoFilterBackend,) @@ -188,7 +203,8 @@ class IntegrationTestPaginationAndFiltering(TestCase): return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) class BasicFilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer paginate_by = 10 filter_backends = (DecimalFilterBackend,) @@ -365,7 +381,7 @@ class TestMaxPaginateByParam(TestCase): # Tests for context in pagination serializers -class CustomField(serializers.Field): +class CustomField(serializers.ReadOnlyField): def to_native(self, value): if 'view' not in self.context: raise RuntimeError("context isn't getting passed into custom field") @@ -375,10 +391,10 @@ class CustomField(serializers.Field): class BasicModelSerializer(serializers.Serializer): text = CustomField() - def __init__(self, *args, **kwargs): - super(BasicModelSerializer, self).__init__(*args, **kwargs) + def to_native(self, value): if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into serializer init") + raise RuntimeError("context isn't getting passed into serializer") + return super(BasicSerializer, self).to_native(value) class TestContextPassedToCustomField(TestCase): @@ -387,7 +403,7 @@ class TestContextPassedToCustomField(TestCase): def test_with_pagination(self): class ListView(generics.ListCreateAPIView): - model = BasicModel + queryset = BasicModel.objects.all() serializer_class = BasicModelSerializer paginate_by = 1 @@ -407,7 +423,7 @@ class LinksSerializer(serializers.Serializer): class CustomPaginationSerializer(pagination.BasePaginationSerializer): links = LinksSerializer(source='*') # Takes the page object as the source - total_results = serializers.Field(source='paginator.count') + total_results = serializers.ReadOnlyField(source='paginator.count') results_field = 'objects' diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 3f2672df..d28d8bd4 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -from rest_framework.compat import StringIO from django import forms from django.core.files.uploadhandler import MemoryFileUploadHandler from django.test import TestCase from django.utils import unittest +from django.utils.six.moves import StringIO from rest_framework.compat import etree +from rest_framework.exceptions import ParseError from rest_framework.parsers import FormParser, FileUploadParser from rest_framework.parsers import XMLParser import datetime @@ -104,13 +105,40 @@ class TestFileUploadParser(TestCase): self.parser_context = {'request': request, 'kwargs': {}} def test_parse(self): - """ Make sure the `QueryDict` works OK """ + """ + Parse raw file upload. + """ parser = FileUploadParser() self.stream.seek(0) data_and_files = parser.parse(self.stream, None, self.parser_context) file_obj = data_and_files.files['file'] self.assertEqual(file_obj._size, 14) + def test_parse_missing_filename(self): + """ + Parse raw file upload when filename is missing. + """ + parser = FileUploadParser() + self.stream.seek(0) + self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' + with self.assertRaises(ParseError): + parser.parse(self.stream, None, self.parser_context) + + def test_parse_missing_filename_multiple_upload_handlers(self): + """ + Parse raw file upload with multiple handlers when filename is missing. + Regression test for #2109. + """ + parser = FileUploadParser() + self.stream.seek(0) + self.parser_context['request'].upload_handlers = ( + MemoryFileUploadHandler(), + MemoryFileUploadHandler() + ) + self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' + with self.assertRaises(ParseError): + parser.parse(self.stream, None, self.parser_context) + def test_get_filename(self): parser = FileUploadParser() filename = parser.get_filename(self.stream, None, self.parser_context) diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 93f8020f..97bac33d 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -3,7 +3,7 @@ from django.contrib.auth.models import User, Permission, Group from django.db import models from django.test import TestCase from django.utils import unittest -from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING +from rest_framework import generics, serializers, status, permissions, authentication, HTTP_HEADER_ENCODING from rest_framework.compat import guardian, get_model_name from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.test import APIRequestFactory @@ -13,14 +13,21 @@ import base64 factory = APIRequestFactory() +class BasicSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + class RootView(generics.ListCreateAPIView): - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer authentication_classes = [authentication.BasicAuthentication] permission_classes = [permissions.DjangoModelPermissions] class InstanceView(generics.RetrieveUpdateDestroyAPIView): - model = BasicModel + queryset = BasicModel.objects.all() + serializer_class = BasicSerializer authentication_classes = [authentication.BasicAuthentication] permission_classes = [permissions.DjangoModelPermissions] @@ -88,19 +95,6 @@ class ModelPermissionsIntegrationTests(TestCase): response = instance_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_has_put_as_create_permissions(self): - # User only has update permissions - should be able to update an entity. - request = factory.put('/1', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.updateonly_credentials) - response = instance_view(request, pk='1') - self.assertEqual(response.status_code, status.HTTP_200_OK) - - # But if PUTing to a new entity, permission should be denied. - request = factory.put('/2', {'text': 'foobar'}, format='json', - HTTP_AUTHORIZATION=self.updateonly_credentials) - response = instance_view(request, pk='2') - self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) - def test_options_permitted(self): request = factory.options( '/', @@ -167,6 +161,11 @@ class BasicPermModel(models.Model): ) +class BasicPermSerializer(serializers.ModelSerializer): + class Meta: + model = BasicPermModel + + # Custom object-level permission, that includes 'view' permissions class ViewObjectPermissions(permissions.DjangoObjectPermissions): perms_map = { @@ -181,7 +180,8 @@ class ViewObjectPermissions(permissions.DjangoObjectPermissions): class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView): - model = BasicPermModel + queryset = BasicPermModel.objects.all() + serializer_class = BasicPermSerializer authentication_classes = [authentication.BasicAuthentication] permission_classes = [ViewObjectPermissions] @@ -189,7 +189,8 @@ object_permissions_view = ObjectPermissionInstanceView.as_view() class ObjectPermissionListView(generics.ListAPIView): - model = BasicPermModel + queryset = BasicPermModel.objects.all() + serializer_class = BasicPermSerializer authentication_classes = [authentication.BasicAuthentication] permission_classes = [ViewObjectPermissions] diff --git a/tests/test_relations.py b/tests/test_relations.py index 501a9208..62353dc2 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,150 +1,136 @@ -""" -General tests for relational fields. -""" -from __future__ import unicode_literals -from django import get_version -from django.db import models -from django.test import TestCase -from django.utils import unittest +from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset +from django.core.exceptions import ImproperlyConfigured from rest_framework import serializers -from tests.models import BlogPost - - -class NullModel(models.Model): - pass - - -class FieldTests(TestCase): - def test_pk_related_field_with_empty_string(self): - """ - Regression test for #446 - - https://github.com/tomchristie/django-rest-framework/issues/446 - """ - field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - def test_hyperlinked_related_field_with_empty_string(self): - field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - def test_slug_related_field_with_empty_string(self): - field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - -class TestManyRelatedMixin(TestCase): - def test_missing_many_to_many_related_field(self): - ''' - Regression test for #632 - - https://github.com/tomchristie/django-rest-framework/pull/632 - ''' - field = serializers.RelatedField(many=True, read_only=False) - - into = {} - field.field_from_native({}, None, 'field_name', into) - self.assertEqual(into['field_name'], []) - - -# Regression tests for #694 (`source` attribute on related fields) - -class RelatedFieldSourceTests(TestCase): - def test_related_manager_source(self): - """ - Relational fields should be able to use manager-returning methods as their source. - """ - BlogPost.objects.create(title='blah') - field = serializers.RelatedField(many=True, source='get_blogposts_manager') - - class ClassWithManagerMethod(object): - def get_blogposts_manager(self): - return BlogPost.objects - - obj = ClassWithManagerMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, ['BlogPost object']) - - def test_related_queryset_source(self): - """ - Relational fields should be able to use queryset-returning methods as their source. - """ - BlogPost.objects.create(title='blah') - field = serializers.RelatedField(many=True, source='get_blogposts_queryset') - - class ClassWithQuerysetMethod(object): - def get_blogposts_queryset(self): - return BlogPost.objects.all() - - obj = ClassWithQuerysetMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, ['BlogPost object']) - - def test_dotted_source(self): - """ - Source argument should support dotted.source notation. +from rest_framework.test import APISimpleTestCase +import pytest + + +class TestStringRelatedField(APISimpleTestCase): + def setUp(self): + self.instance = MockObject(pk=1, name='foo') + self.field = serializers.StringRelatedField() + + def test_string_related_representation(self): + representation = self.field.to_representation(self.instance) + assert representation == '<MockObject name=foo, pk=1>' + + +class TestPrimaryKeyRelatedField(APISimpleTestCase): + def setUp(self): + self.queryset = MockQueryset([ + MockObject(pk=1, name='foo'), + MockObject(pk=2, name='bar'), + MockObject(pk=3, name='baz') + ]) + self.instance = self.queryset.items[2] + self.field = serializers.PrimaryKeyRelatedField(queryset=self.queryset) + + def test_pk_related_lookup_exists(self): + instance = self.field.to_internal_value(self.instance.pk) + assert instance is self.instance + + def test_pk_related_lookup_does_not_exist(self): + with pytest.raises(serializers.ValidationError) as excinfo: + self.field.to_internal_value(4) + msg = excinfo.value.detail[0] + assert msg == "Invalid pk '4' - object does not exist." + + def test_pk_related_lookup_invalid_type(self): + with pytest.raises(serializers.ValidationError) as excinfo: + self.field.to_internal_value(BadType()) + msg = excinfo.value.detail[0] + assert msg == 'Incorrect type. Expected pk value, received BadType.' + + def test_pk_representation(self): + representation = self.field.to_representation(self.instance) + assert representation == self.instance.pk + + +class TestHyperlinkedIdentityField(APISimpleTestCase): + def setUp(self): + self.instance = MockObject(pk=1, name='foo') + self.field = serializers.HyperlinkedIdentityField(view_name='example') + self.field.reverse = mock_reverse + self.field._context = {'request': True} + + def test_representation(self): + representation = self.field.to_representation(self.instance) + assert representation == 'http://example.org/example/1/' + + def test_representation_unsaved_object(self): + representation = self.field.to_representation(MockObject(pk=None)) + assert representation is None + + def test_representation_with_format(self): + self.field._context['format'] = 'xml' + representation = self.field.to_representation(self.instance) + assert representation == 'http://example.org/example/1.xml/' + + def test_improperly_configured(self): """ - BlogPost.objects.create(title='blah') - field = serializers.RelatedField(many=True, source='a.b.c') - - class ClassWithQuerysetMethod(object): - a = { - 'b': { - 'c': BlogPost.objects.all() - } - } - - obj = ClassWithQuerysetMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, ['BlogPost object']) - - # Regression for #1129 - def test_exception_for_incorrect_fk(self): + If a matching view cannot be reversed with the given instance, + the the user has misconfigured something, as the URL conf and the + hyperlinked field do not match. """ - Check that the exception message are correct if the source field - doesn't exist. - """ - from tests.models import ManyToManySource - - class Meta: - model = ManyToManySource + self.field.reverse = fail_reverse + with pytest.raises(ImproperlyConfigured): + self.field.to_representation(self.instance) - attrs = { - 'name': serializers.SlugRelatedField( - slug_field='name', source='banzai'), - 'Meta': Meta, - } - TestSerializer = type( - str('TestSerializer'), - (serializers.ModelSerializer,), - attrs - ) - serializer = TestSerializer(data={'name': 'foo'}) - with self.assertRaises(AttributeError): - serializer.fields - - -@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') -class RelatedFieldChoicesTests(TestCase): +class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase): """ - Tests for #1408 "Web browseable API doesn't have blank option on drop down list box" - https://github.com/tomchristie/django-rest-framework/issues/1408 - """ - def test_blank_option_is_added_to_choice_if_required_equals_false(self): - """ + Tests for a hyperlinked identity field that has a `format` set, + which enforces that alternate formats are never linked too. - """ - post = BlogPost(title="Checking blank option is added") - post.save() - - queryset = BlogPost.objects.all() - field = serializers.RelatedField(required=False, queryset=queryset) + Eg. If your API includes some endpoints that accept both `.xml` and `.json`, + but other endpoints that only accept `.json`, we allow for hyperlinked + relationships that enforce only a single suffix type. + """ - choice_count = BlogPost.objects.count() - widget_count = len(field.widget.choices) + def setUp(self): + self.instance = MockObject(pk=1, name='foo') + self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json') + self.field.reverse = mock_reverse + self.field._context = {'request': True} + + def test_representation(self): + representation = self.field.to_representation(self.instance) + assert representation == 'http://example.org/example/1/' + + def test_representation_with_format(self): + self.field._context['format'] = 'xml' + representation = self.field.to_representation(self.instance) + assert representation == 'http://example.org/example/1.json/' + + +class TestSlugRelatedField(APISimpleTestCase): + def setUp(self): + self.queryset = MockQueryset([ + MockObject(pk=1, name='foo'), + MockObject(pk=2, name='bar'), + MockObject(pk=3, name='baz') + ]) + self.instance = self.queryset.items[2] + self.field = serializers.SlugRelatedField( + slug_field='name', queryset=self.queryset + ) - self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') + def test_slug_related_lookup_exists(self): + instance = self.field.to_internal_value(self.instance.name) + assert instance is self.instance + + def test_slug_related_lookup_does_not_exist(self): + with pytest.raises(serializers.ValidationError) as excinfo: + self.field.to_internal_value('doesnotexist') + msg = excinfo.value.detail[0] + assert msg == 'Object with name=doesnotexist does not exist.' + + def test_slug_related_lookup_invalid_type(self): + with pytest.raises(serializers.ValidationError) as excinfo: + self.field.to_internal_value(BadType()) + msg = excinfo.value.detail[0] + assert msg == 'Invalid value.' + + def test_representation(self): + representation = self.field.to_representation(self.instance) + assert representation == self.instance.name diff --git a/tests/test_genericrelations.py b/tests/test_relations_generic.py index 95295eaa..b600b333 100644 --- a/tests/test_genericrelations.py +++ b/tests/test_relations_generic.py @@ -3,8 +3,8 @@ from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey from django.db import models from django.test import TestCase +from django.utils.encoding import python_2_unicode_compatible from rest_framework import serializers -from rest_framework.compat import python_2_unicode_compatible @python_2_unicode_compatible @@ -60,11 +60,11 @@ class TestGenericRelations(TestCase): """ class BookmarkSerializer(serializers.ModelSerializer): - tags = serializers.RelatedField(many=True) + tags = serializers.StringRelatedField(many=True) class Meta: model = Bookmark - exclude = ('id',) + fields = ('tags', 'url') serializer = BookmarkSerializer(self.bookmark) expected = { @@ -73,35 +73,6 @@ class TestGenericRelations(TestCase): } self.assertEqual(serializer.data, expected) - def test_generic_nested_relation(self): - """ - Test saving a GenericRelation field via a nested serializer. - """ - - class TagSerializer(serializers.ModelSerializer): - class Meta: - model = Tag - exclude = ('content_type', 'object_id') - - class BookmarkSerializer(serializers.ModelSerializer): - tags = TagSerializer(many=True) - - class Meta: - model = Bookmark - exclude = ('id',) - - data = { - 'url': 'https://docs.djangoproject.com/', - 'tags': [ - {'tag': 'contenttypes'}, - {'tag': 'genericrelations'}, - ] - } - serializer = BookmarkSerializer(data=data) - self.assertTrue(serializer.is_valid()) - serializer.save() - self.assertEqual(serializer.object.tags.count(), 2) - def test_generic_fk(self): """ Test a relationship that spans a GenericForeignKey field. @@ -109,11 +80,11 @@ class TestGenericRelations(TestCase): """ class TagSerializer(serializers.ModelSerializer): - tagged_item = serializers.RelatedField() + tagged_item = serializers.StringRelatedField() class Meta: model = Tag - exclude = ('id', 'content_type', 'object_id') + fields = ('tag', 'tagged_item') serializer = TagSerializer(Tag.objects.all(), many=True) expected = [ @@ -131,21 +102,3 @@ class TestGenericRelations(TestCase): } ] self.assertEqual(serializer.data, expected) - - def test_restore_object_generic_fk(self): - """ - Ensure an object with a generic foreign key can be restored. - """ - - class TagSerializer(serializers.ModelSerializer): - class Meta: - model = Tag - exclude = ('content_type', 'object_id') - - serializer = TagSerializer() - - bookmark = Bookmark(url='http://example.com') - attrs = {'tagged_item': bookmark, 'tag': 'example'} - - tag = serializer.restore_object(attrs) - self.assertEqual(tag.tagged_item, bookmark) diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index 0c8eb254..f1b882ed 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -4,7 +4,6 @@ from django.test import TestCase from rest_framework import serializers from rest_framework.test import APIRequestFactory from tests.models import ( - BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) @@ -13,8 +12,7 @@ factory = APIRequestFactory() request = factory.get('/') # Just to ensure we have a request in the serializer context -def dummy_view(request, pk): - pass +dummy_view = lambda request, pk: None urlpatterns = patterns( '', @@ -91,7 +89,14 @@ class HyperlinkedManyToManyTests(TestCase): {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(4): + self.assertEqual(serializer.data, expected) + + def test_many_to_many_retrieve_prefetch_related(self): + queryset = ManyToManySource.objects.all().prefetch_related('targets') + serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) + with self.assertNumQueries(2): + serializer.data def test_reverse_many_to_many_retrieve(self): queryset = ManyToManyTarget.objects.all() @@ -101,7 +106,8 @@ class HyperlinkedManyToManyTests(TestCase): {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(4): + self.assertEqual(serializer.data, expected) def test_many_to_many_update(self): data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} @@ -199,7 +205,8 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(1): + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() @@ -208,15 +215,16 @@ class HyperlinkedForeignKeyTests(TestCase): {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(3): + self.assertEqual(serializer.data, expected) def test_foreign_key_update(self): data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() @@ -233,7 +241,7 @@ class HyperlinkedForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']}) + self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected URL string, received int.']}) def test_reverse_foreign_key_update(self): data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} @@ -304,7 +312,7 @@ class HyperlinkedForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) class HyperlinkedNullableForeignKeyTests(TestCase): @@ -377,8 +385,8 @@ class HyperlinkedNullableForeignKeyTests(TestCase): instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -400,8 +408,8 @@ class HyperlinkedNullableForeignKeyTests(TestCase): instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, expected_data) serializer.save() + self.assertEqual(serializer.data, expected_data) # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -413,27 +421,6 @@ class HyperlinkedNullableForeignKeyTests(TestCase): ] self.assertEqual(serializer.data, expected) - # reverse foreign keys MUST be read_only - # In the general case they do not provide .remove() or .clear() - # and cannot be arbitrarily set. - - # def test_reverse_foreign_key_update(self): - # data = {'id': 1, 'name': 'target-1', 'sources': [1]} - # instance = ForeignKeyTarget.objects.get(pk=1) - # serializer = ForeignKeyTargetSerializer(instance, data=data) - # self.assertTrue(serializer.is_valid()) - # self.assertEqual(serializer.data, data) - # serializer.save() - - # # Ensure target 1 is updated, and everything else is as expected - # queryset = ForeignKeyTarget.objects.all() - # serializer = ForeignKeyTargetSerializer(queryset, many=True) - # expected = [ - # {'id': 1, 'name': 'target-1', 'sources': [1]}, - # {'id': 2, 'name': 'target-2', 'sources': []}, - # ] - # self.assertEqual(serializer.data, expected) - class HyperlinkedNullableOneToOneTests(TestCase): urls = 'tests.test_relations_hyperlink' @@ -454,72 +441,3 @@ class HyperlinkedNullableOneToOneTests(TestCase): {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ] self.assertEqual(serializer.data, expected) - - -# Regression tests for #694 (`source` attribute on related fields) - -class HyperlinkedRelatedFieldSourceTests(TestCase): - urls = 'tests.test_relations_hyperlink' - - def test_related_manager_source(self): - """ - Relational fields should be able to use manager-returning methods as their source. - """ - BlogPost.objects.create(title='blah') - field = serializers.HyperlinkedRelatedField( - many=True, - source='get_blogposts_manager', - view_name='dummy-url', - ) - field.context = {'request': request} - - class ClassWithManagerMethod(object): - def get_blogposts_manager(self): - return BlogPost.objects - - obj = ClassWithManagerMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, ['http://testserver/dummyurl/1/']) - - def test_related_queryset_source(self): - """ - Relational fields should be able to use queryset-returning methods as their source. - """ - BlogPost.objects.create(title='blah') - field = serializers.HyperlinkedRelatedField( - many=True, - source='get_blogposts_queryset', - view_name='dummy-url', - ) - field.context = {'request': request} - - class ClassWithQuerysetMethod(object): - def get_blogposts_queryset(self): - return BlogPost.objects.all() - - obj = ClassWithQuerysetMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, ['http://testserver/dummyurl/1/']) - - def test_dotted_source(self): - """ - Source argument should support dotted.source notation. - """ - BlogPost.objects.create(title='blah') - field = serializers.HyperlinkedRelatedField( - many=True, - source='a.b.c', - view_name='dummy-url', - ) - field.context = {'request': request} - - class ClassWithQuerysetMethod(object): - a = { - 'b': { - 'c': BlogPost.objects.all() - } - } - - obj = ClassWithQuerysetMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, ['http://testserver/dummyurl/1/']) diff --git a/tests/test_relations_nested.py b/tests/test_relations_nested.py deleted file mode 100644 index 4d9da489..00000000 --- a/tests/test_relations_nested.py +++ /dev/null @@ -1,326 +0,0 @@ -from __future__ import unicode_literals -from django.db import models -from django.test import TestCase -from rest_framework import serializers - -from .models import OneToOneTarget - - -class OneToOneSource(models.Model): - name = models.CharField(max_length=100) - target = models.OneToOneField(OneToOneTarget, related_name='source', - null=True, blank=True) - - -class OneToManyTarget(models.Model): - name = models.CharField(max_length=100) - - -class OneToManySource(models.Model): - name = models.CharField(max_length=100) - target = models.ForeignKey(OneToManyTarget, related_name='sources') - - -class ReverseNestedOneToOneTests(TestCase): - def setUp(self): - class OneToOneSourceSerializer(serializers.ModelSerializer): - class Meta: - model = OneToOneSource - fields = ('id', 'name') - - class OneToOneTargetSerializer(serializers.ModelSerializer): - source = OneToOneSourceSerializer() - - class Meta: - model = OneToOneTarget - fields = ('id', 'name', 'source') - - self.Serializer = OneToOneTargetSerializer - - for idx in range(1, 4): - target = OneToOneTarget(name='target-%d' % idx) - target.save() - source = OneToOneSource(name='source-%d' % idx, target=target) - source.save() - - def test_one_to_one_retrieve(self): - queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, - {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, - {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} - ] - self.assertEqual(serializer.data, expected) - - def test_one_to_one_create(self): - data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} - serializer = self.Serializer(data=data) - self.assertTrue(serializer.is_valid()) - obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'target-4') - - # Ensure (target 4, target_source 4, source 4) are added, and - # everything else is as expected. - queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, - {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, - {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, - {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} - ] - self.assertEqual(serializer.data, expected) - - def test_one_to_one_create_with_invalid_data(self): - data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} - serializer = self.Serializer(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) - - def test_one_to_one_update(self): - data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} - instance = OneToOneTarget.objects.get(pk=3) - serializer = self.Serializer(instance, data=data) - self.assertTrue(serializer.is_valid()) - obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'target-3-updated') - - # Ensure (target 3, target_source 3, source 3) are updated, - # and everything else is as expected. - queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, - {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, - {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} - ] - self.assertEqual(serializer.data, expected) - - -class ForwardNestedOneToOneTests(TestCase): - def setUp(self): - class OneToOneTargetSerializer(serializers.ModelSerializer): - class Meta: - model = OneToOneTarget - fields = ('id', 'name') - - class OneToOneSourceSerializer(serializers.ModelSerializer): - target = OneToOneTargetSerializer() - - class Meta: - model = OneToOneSource - fields = ('id', 'name', 'target') - - self.Serializer = OneToOneSourceSerializer - - for idx in range(1, 4): - target = OneToOneTarget(name='target-%d' % idx) - target.save() - source = OneToOneSource(name='source-%d' % idx, target=target) - source.save() - - def test_one_to_one_retrieve(self): - queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, - {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} - ] - self.assertEqual(serializer.data, expected) - - def test_one_to_one_create(self): - data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} - serializer = self.Serializer(data=data) - self.assertTrue(serializer.is_valid()) - obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'source-4') - - # Ensure (target 4, target_source 4, source 4) are added, and - # everything else is as expected. - queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, - {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, - {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} - ] - self.assertEqual(serializer.data, expected) - - def test_one_to_one_create_with_invalid_data(self): - data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} - serializer = self.Serializer(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) - - def test_one_to_one_update(self): - data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} - instance = OneToOneSource.objects.get(pk=3) - serializer = self.Serializer(instance, data=data) - self.assertTrue(serializer.is_valid()) - obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'source-3-updated') - - # Ensure (target 3, target_source 3, source 3) are updated, - # and everything else is as expected. - queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, - {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} - ] - self.assertEqual(serializer.data, expected) - - def test_one_to_one_update_to_null(self): - data = {'id': 3, 'name': 'source-3-updated', 'target': None} - instance = OneToOneSource.objects.get(pk=3) - serializer = self.Serializer(instance, data=data) - self.assertTrue(serializer.is_valid()) - obj = serializer.save() - - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'source-3-updated') - self.assertEqual(obj.target, None) - - queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, - {'id': 3, 'name': 'source-3-updated', 'target': None} - ] - self.assertEqual(serializer.data, expected) - - # TODO: Nullable 1-1 tests - # def test_one_to_one_delete(self): - # data = {'id': 3, 'name': 'target-3', 'target_source': None} - # instance = OneToOneTarget.objects.get(pk=3) - # serializer = self.Serializer(instance, data=data) - # self.assertTrue(serializer.is_valid()) - # serializer.save() - - # # Ensure (target_source 3, source 3) are deleted, - # # and everything else is as expected. - # queryset = OneToOneTarget.objects.all() - # serializer = self.Serializer(queryset) - # expected = [ - # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, - # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, - # {'id': 3, 'name': 'target-3', 'source': None} - # ] - # self.assertEqual(serializer.data, expected) - - -class ReverseNestedOneToManyTests(TestCase): - def setUp(self): - class OneToManySourceSerializer(serializers.ModelSerializer): - class Meta: - model = OneToManySource - fields = ('id', 'name') - - class OneToManyTargetSerializer(serializers.ModelSerializer): - sources = OneToManySourceSerializer(many=True, allow_add_remove=True) - - class Meta: - model = OneToManyTarget - fields = ('id', 'name', 'sources') - - self.Serializer = OneToManyTargetSerializer - - target = OneToManyTarget(name='target-1') - target.save() - for idx in range(1, 4): - source = OneToManySource(name='source-%d' % idx, target=target) - source.save() - - def test_one_to_many_retrieve(self): - queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, - {'id': 2, 'name': 'source-2'}, - {'id': 3, 'name': 'source-3'}]}, - ] - self.assertEqual(serializer.data, expected) - - def test_one_to_many_create(self): - data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, - {'id': 2, 'name': 'source-2'}, - {'id': 3, 'name': 'source-3'}, - {'id': 4, 'name': 'source-4'}]} - instance = OneToManyTarget.objects.get(pk=1) - serializer = self.Serializer(instance, data=data) - self.assertTrue(serializer.is_valid()) - obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'target-1') - - # Ensure source 4 is added, and everything else is as - # expected. - queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, - {'id': 2, 'name': 'source-2'}, - {'id': 3, 'name': 'source-3'}, - {'id': 4, 'name': 'source-4'}]} - ] - self.assertEqual(serializer.data, expected) - - def test_one_to_many_create_with_invalid_data(self): - data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, - {'id': 2, 'name': 'source-2'}, - {'id': 3, 'name': 'source-3'}, - {'id': 4}]} - serializer = self.Serializer(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) - - def test_one_to_many_update(self): - data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, - {'id': 2, 'name': 'source-2'}, - {'id': 3, 'name': 'source-3'}]} - instance = OneToManyTarget.objects.get(pk=1) - serializer = self.Serializer(instance, data=data) - self.assertTrue(serializer.is_valid()) - obj = serializer.save() - self.assertEqual(serializer.data, data) - self.assertEqual(obj.name, 'target-1-updated') - - # Ensure (target 1, source 1) are updated, - # and everything else is as expected. - queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, - {'id': 2, 'name': 'source-2'}, - {'id': 3, 'name': 'source-3'}]} - - ] - self.assertEqual(serializer.data, expected) - - def test_one_to_many_delete(self): - data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, - {'id': 3, 'name': 'source-3'}]} - instance = OneToManyTarget.objects.get(pk=1) - serializer = self.Serializer(instance, data=data) - self.assertTrue(serializer.is_valid()) - serializer.save() - - # Ensure source 2 is deleted, and everything else is as - # expected. - queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset, many=True) - expected = [ - {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, - {'id': 3, 'name': 'source-3'}]} - - ] - self.assertEqual(serializer.data, expected) diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index e3f836ed..f872a8dc 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -1,10 +1,9 @@ from __future__ import unicode_literals -from django.db import models from django.test import TestCase from django.utils import six from rest_framework import serializers from tests.models import ( - BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, + ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, ) @@ -69,7 +68,14 @@ class PKManyToManyTests(TestCase): {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]} ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(4): + self.assertEqual(serializer.data, expected) + + def test_many_to_many_retrieve_prefetch_related(self): + queryset = ManyToManySource.objects.all().prefetch_related('targets') + serializer = ManyToManySourceSerializer(queryset, many=True) + with self.assertNumQueries(2): + serializer.data def test_reverse_many_to_many_retrieve(self): queryset = ManyToManyTarget.objects.all() @@ -79,7 +85,8 @@ class PKManyToManyTests(TestCase): {'id': 2, 'name': 'target-2', 'sources': [2, 3]}, {'id': 3, 'name': 'target-3', 'sources': [3]} ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(4): + self.assertEqual(serializer.data, expected) def test_many_to_many_update(self): data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} @@ -128,7 +135,6 @@ class PKManyToManyTests(TestCase): # Ensure source 4 is added, and everything else is as expected queryset = ManyToManySource.objects.all() serializer = ManyToManySourceSerializer(queryset, many=True) - self.assertFalse(serializer.fields['targets'].read_only) expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, @@ -140,7 +146,6 @@ class PKManyToManyTests(TestCase): def test_reverse_many_to_many_create(self): data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} serializer = ManyToManyTargetSerializer(data=data) - self.assertFalse(serializer.fields['sources'].read_only) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEqual(serializer.data, data) @@ -176,7 +181,8 @@ class PKForeignKeyTests(TestCase): {'id': 2, 'name': 'source-2', 'target': 1}, {'id': 3, 'name': 'source-3', 'target': 1} ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(1): + self.assertEqual(serializer.data, expected) def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() @@ -185,15 +191,22 @@ class PKForeignKeyTests(TestCase): {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]}, {'id': 2, 'name': 'target-2', 'sources': []}, ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(3): + self.assertEqual(serializer.data, expected) + + def test_reverse_foreign_key_retrieve_prefetch_related(self): + queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') + serializer = ForeignKeyTargetSerializer(queryset, many=True) + with self.assertNumQueries(2): + serializer.data def test_foreign_key_update(self): data = {'id': 1, 'name': 'source-1', 'target': 2} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() @@ -210,7 +223,7 @@ class PKForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}) + self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}) def test_reverse_foreign_key_update(self): data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} @@ -281,7 +294,7 @@ class PKForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) def test_foreign_key_with_empty(self): """ @@ -361,8 +374,8 @@ class PKNullableForeignKeyTests(TestCase): instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -384,8 +397,8 @@ class PKNullableForeignKeyTests(TestCase): instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, expected_data) serializer.save() + self.assertEqual(serializer.data, expected_data) # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -397,27 +410,6 @@ class PKNullableForeignKeyTests(TestCase): ] self.assertEqual(serializer.data, expected) - # reverse foreign keys MUST be read_only - # In the general case they do not provide .remove() or .clear() - # and cannot be arbitrarily set. - - # def test_reverse_foreign_key_update(self): - # data = {'id': 1, 'name': 'target-1', 'sources': [1]} - # instance = ForeignKeyTarget.objects.get(pk=1) - # serializer = ForeignKeyTargetSerializer(instance, data=data) - # self.assertTrue(serializer.is_valid()) - # self.assertEqual(serializer.data, data) - # serializer.save() - - # # Ensure target 1 is updated, and everything else is as expected - # queryset = ForeignKeyTarget.objects.all() - # serializer = ForeignKeyTargetSerializer(queryset, many=True) - # expected = [ - # {'id': 1, 'name': 'target-1', 'sources': [1]}, - # {'id': 2, 'name': 'target-2', 'sources': []}, - # ] - # self.assertEqual(serializer.data, expected) - class PKNullableOneToOneTests(TestCase): def setUp(self): @@ -436,116 +428,3 @@ class PKNullableOneToOneTests(TestCase): {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ] self.assertEqual(serializer.data, expected) - - -# The below models and tests ensure that serializer fields corresponding -# to a ManyToManyField field with a user-specified ``through`` model are -# set to read only - - -class ManyToManyThroughTarget(models.Model): - name = models.CharField(max_length=100) - - -class ManyToManyThrough(models.Model): - source = models.ForeignKey('ManyToManyThroughSource') - target = models.ForeignKey(ManyToManyThroughTarget) - - -class ManyToManyThroughSource(models.Model): - name = models.CharField(max_length=100) - targets = models.ManyToManyField(ManyToManyThroughTarget, - related_name='sources', - through='ManyToManyThrough') - - -class ManyToManyThroughTargetSerializer(serializers.ModelSerializer): - class Meta: - model = ManyToManyThroughTarget - fields = ('id', 'name', 'sources') - - -class ManyToManyThroughSourceSerializer(serializers.ModelSerializer): - class Meta: - model = ManyToManyThroughSource - fields = ('id', 'name', 'targets') - - -class PKManyToManyThroughTests(TestCase): - def setUp(self): - self.source = ManyToManyThroughSource.objects.create( - name='through-source-1') - self.target = ManyToManyThroughTarget.objects.create( - name='through-target-1') - - def test_many_to_many_create(self): - data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]} - serializer = ManyToManyThroughSourceSerializer(data=data) - self.assertTrue(serializer.fields['targets'].read_only) - self.assertTrue(serializer.is_valid()) - obj = serializer.save() - self.assertEqual(obj.name, 'source-2') - self.assertEqual(obj.targets.count(), 0) - - def test_many_to_many_reverse_create(self): - data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]} - serializer = ManyToManyThroughTargetSerializer(data=data) - self.assertTrue(serializer.fields['sources'].read_only) - self.assertTrue(serializer.is_valid()) - serializer.save() - obj = serializer.save() - self.assertEqual(obj.name, 'target-2') - self.assertEqual(obj.sources.count(), 0) - - -# Regression tests for #694 (`source` attribute on related fields) - - -class PrimaryKeyRelatedFieldSourceTests(TestCase): - def test_related_manager_source(self): - """ - Relational fields should be able to use manager-returning methods as their source. - """ - BlogPost.objects.create(title='blah') - field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager') - - class ClassWithManagerMethod(object): - def get_blogposts_manager(self): - return BlogPost.objects - - obj = ClassWithManagerMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, [1]) - - def test_related_queryset_source(self): - """ - Relational fields should be able to use queryset-returning methods as their source. - """ - BlogPost.objects.create(title='blah') - field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset') - - class ClassWithQuerysetMethod(object): - def get_blogposts_queryset(self): - return BlogPost.objects.all() - - obj = ClassWithQuerysetMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, [1]) - - def test_dotted_source(self): - """ - Source argument should support dotted.source notation. - """ - BlogPost.objects.create(title='blah') - field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c') - - class ClassWithQuerysetMethod(object): - a = { - 'b': { - 'c': BlogPost.objects.all() - } - } - - obj = ClassWithQuerysetMethod() - value = field.field_to_native(obj, 'field_name') - self.assertEqual(value, [1]) diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py index 97ebf23a..cd2cb1ed 100644 --- a/tests/test_relations_slug.py +++ b/tests/test_relations_slug.py @@ -4,21 +4,32 @@ from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyT class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = serializers.SlugRelatedField(many=True, slug_field='name') + sources = serializers.SlugRelatedField( + slug_field='name', + queryset=ForeignKeySource.objects.all(), + many=True + ) class Meta: model = ForeignKeyTarget class ForeignKeySourceSerializer(serializers.ModelSerializer): - target = serializers.SlugRelatedField(slug_field='name') + target = serializers.SlugRelatedField( + slug_field='name', + queryset=ForeignKeyTarget.objects.all() + ) class Meta: model = ForeignKeySource class NullableForeignKeySourceSerializer(serializers.ModelSerializer): - target = serializers.SlugRelatedField(slug_field='name', required=False) + target = serializers.SlugRelatedField( + slug_field='name', + queryset=ForeignKeyTarget.objects.all(), + allow_null=True + ) class Meta: model = NullableForeignKeySource @@ -43,7 +54,14 @@ class SlugForeignKeyTests(TestCase): {'id': 2, 'name': 'source-2', 'target': 'target-1'}, {'id': 3, 'name': 'source-3', 'target': 'target-1'} ] - self.assertEqual(serializer.data, expected) + with self.assertNumQueries(4): + self.assertEqual(serializer.data, expected) + + def test_foreign_key_retrieve_select_related(self): + queryset = ForeignKeySource.objects.all().select_related('target') + serializer = ForeignKeySourceSerializer(queryset, many=True) + with self.assertNumQueries(1): + serializer.data def test_reverse_foreign_key_retrieve(self): queryset = ForeignKeyTarget.objects.all() @@ -54,13 +72,19 @@ class SlugForeignKeyTests(TestCase): ] self.assertEqual(serializer.data, expected) + def test_reverse_foreign_key_retrieve_prefetch_related(self): + queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') + serializer = ForeignKeyTargetSerializer(queryset, many=True) + with self.assertNumQueries(2): + serializer.data + def test_foreign_key_update(self): data = {'id': 1, 'name': 'source-1', 'target': 'target-2'} instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = ForeignKeySource.objects.all() @@ -149,7 +173,7 @@ class SlugForeignKeyTests(TestCase): instance = ForeignKeySource.objects.get(pk=1) serializer = ForeignKeySourceSerializer(instance, data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) class SlugNullableForeignKeyTests(TestCase): @@ -220,8 +244,8 @@ class SlugNullableForeignKeyTests(TestCase): instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, data) serializer.save() + self.assertEqual(serializer.data, data) # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() @@ -243,8 +267,8 @@ class SlugNullableForeignKeyTests(TestCase): instance = NullableForeignKeySource.objects.get(pk=1) serializer = NullableForeignKeySourceSerializer(instance, data=data) self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, expected_data) serializer.save() + self.assertEqual(serializer.data, expected_data) # Ensure source 1 is updated, and everything else is as expected queryset = NullableForeignKeySource.objects.all() diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 91244e26..00a24fb1 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -7,13 +7,15 @@ from django.core.cache import cache from django.db import models from django.test import TestCase from django.utils import six, unittest +from django.utils.six import BytesIO +from django.utils.six.moves import StringIO from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions -from rest_framework.compat import yaml, etree, StringIO +from rest_framework.compat import yaml, etree from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ - XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer, UnicodeYAMLRenderer + XMLRenderer, JSONPRenderer, BrowsableAPIRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory @@ -32,7 +34,7 @@ RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') expected_results = [ - ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator + ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1,2,3]') # Generator ] @@ -270,7 +272,7 @@ class RendererEndToEndTests(TestCase): self.assertNotContains(resp, '>text/html; charset=utf-8<') -_flat_repr = '{"foo": ["bar", "baz"]}' +_flat_repr = '{"foo":["bar","baz"]}' _indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}' @@ -373,22 +375,38 @@ class JSONRendererTests(TestCase): content = renderer.render(obj, 'application/json; indent=2') self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) - def test_check_ascii(self): + +class UnicodeJSONRendererTests(TestCase): + """ + Tests specific for the Unicode JSON Renderer + """ + def test_proper_encoding(self): obj = {'countries': ['United Kingdom', 'France', 'España']} renderer = JSONRenderer() content = renderer.render(obj, 'application/json') - self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8')) + self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode('utf-8')) + def test_u2028_u2029(self): + # The \u2028 and \u2029 characters should be escaped, + # even when the non-escaping unicode representation is used. + # Regression test for #2169 + obj = {'should_escape': '\u2028\u2029'} + renderer = JSONRenderer() + content = renderer.render(obj, 'application/json') + self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode('utf-8')) -class UnicodeJSONRendererTests(TestCase): + +class AsciiJSONRendererTests(TestCase): """ Tests specific for the Unicode JSON Renderer """ def test_proper_encoding(self): + class AsciiJSONRenderer(JSONRenderer): + ensure_ascii = True obj = {'countries': ['United Kingdom', 'France', 'España']} - renderer = UnicodeJSONRenderer() + renderer = AsciiJSONRenderer() content = renderer.render(obj, 'application/json') - self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8')) + self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8')) class JSONPRendererTests(TestCase): @@ -460,7 +478,7 @@ if yaml: obj = {'foo': ['bar', 'baz']} renderer = YAMLRenderer() content = renderer.render(obj, 'application/yaml') - self.assertEqual(content, _yaml_repr) + self.assertEqual(content.decode('utf-8'), _yaml_repr) def test_render_and_parse(self): """ @@ -473,7 +491,7 @@ if yaml: parser = YAMLParser() content = renderer.render(obj, 'application/yaml') - data = parser.parse(StringIO(content)) + data = parser.parse(BytesIO(content)) self.assertEqual(obj, data) def test_render_decimal(self): @@ -482,18 +500,14 @@ if yaml: """ renderer = YAMLRenderer() content = renderer.render({'field': Decimal('111.2')}, 'application/yaml') - self.assertYAMLContains(content, "field: '111.2'") + self.assertYAMLContains(content.decode('utf-8'), "field: '111.2'") def assertYAMLContains(self, content, string): self.assertTrue(string in content, '%r not in %r' % (string, content)) - class UnicodeYAMLRendererTests(TestCase): - """ - Tests specific for the Unicode YAML Renderer - """ def test_proper_encoding(self): obj = {'countries': ['United Kingdom', 'France', 'España']} - renderer = UnicodeYAMLRenderer() + renderer = YAMLRenderer() content = renderer.render(obj, 'application/yaml') self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8')) @@ -643,6 +657,7 @@ class CacheRenderTest(TestCase): """ method = getattr(self.client, http_method) resp = method(url) + resp._closable_objects = [] del resp.client, resp.request try: del resp.wsgi_request diff --git a/tests/test_request.py b/tests/test_request.py index 8ddaf0a7..02a9b1e2 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -179,89 +179,6 @@ class TestContentParsing(TestCase): self.assertEqual(request._data, Empty) self.assertEqual(request._files, Empty) - # def test_accessing_post_after_data_form(self): - # """ - # Ensures request.POST can be accessed after request.DATA in - # form request. - # """ - # data = {'qwerty': 'uiop'} - # request = factory.post('/', data=data) - # self.assertEqual(request.DATA.items(), data.items()) - # self.assertEqual(request.POST.items(), data.items()) - - # def test_accessing_post_after_data_for_json(self): - # """ - # Ensures request.POST can be accessed after request.DATA in - # json request. - # """ - # data = {'qwerty': 'uiop'} - # content = json.dumps(data) - # content_type = 'application/json' - # parsers = (JSONParser, ) - - # request = factory.post('/', content, content_type=content_type, - # parsers=parsers) - # self.assertEqual(request.DATA.items(), data.items()) - # self.assertEqual(request.POST.items(), []) - - # def test_accessing_post_after_data_for_overloaded_json(self): - # """ - # Ensures request.POST can be accessed after request.DATA in overloaded - # json request. - # """ - # data = {'qwerty': 'uiop'} - # content = json.dumps(data) - # content_type = 'application/json' - # parsers = (JSONParser, ) - # form_data = {Request._CONTENT_PARAM: content, - # Request._CONTENTTYPE_PARAM: content_type} - - # request = factory.post('/', form_data, parsers=parsers) - # self.assertEqual(request.DATA.items(), data.items()) - # self.assertEqual(request.POST.items(), form_data.items()) - - # def test_accessing_data_after_post_form(self): - # """ - # Ensures request.DATA can be accessed after request.POST in - # form request. - # """ - # data = {'qwerty': 'uiop'} - # parsers = (FormParser, MultiPartParser) - # request = factory.post('/', data, parsers=parsers) - - # self.assertEqual(request.POST.items(), data.items()) - # self.assertEqual(request.DATA.items(), data.items()) - - # def test_accessing_data_after_post_for_json(self): - # """ - # Ensures request.DATA can be accessed after request.POST in - # json request. - # """ - # data = {'qwerty': 'uiop'} - # content = json.dumps(data) - # content_type = 'application/json' - # parsers = (JSONParser, ) - # request = factory.post('/', content, content_type=content_type, - # parsers=parsers) - # self.assertEqual(request.POST.items(), []) - # self.assertEqual(request.DATA.items(), data.items()) - - # def test_accessing_data_after_post_for_overloaded_json(self): - # """ - # Ensures request.DATA can be accessed after request.POST in overloaded - # json request - # """ - # data = {'qwerty': 'uiop'} - # content = json.dumps(data) - # content_type = 'application/json' - # parsers = (JSONParser, ) - # form_data = {Request._CONTENT_PARAM: content, - # Request._CONTENTTYPE_PARAM: content_type} - - # request = factory.post('/', form_data, parsers=parsers) - # self.assertEqual(request.POST.items(), form_data.items()) - # self.assertEqual(request.DATA.items(), data.items()) - class MockView(APIView): authentication_classes = (SessionAuthentication,) @@ -270,7 +187,7 @@ class MockView(APIView): if request.POST.get('example') is not None: return Response(status=status.HTTP_200_OK) - return Response(status=status.INTERNAL_SERVER_ERROR) + return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR) urlpatterns = patterns( '', @@ -301,25 +218,14 @@ class TestContentParsingWithAuthentication(TestCase): response = self.csrf_client.post('/', content) self.assertEqual(status.HTTP_200_OK, response.status_code) - # def test_user_logged_in_authentication_has_post_when_logged_in(self): - # """Ensures request.POST exists after UserLoggedInAuthentication when user does log in""" - # self.client.login(username='john', password='password') - # self.csrf_client.login(username='john', password='password') - # content = {'example': 'example'} - - # response = self.client.post('/', content) - # self.assertEqual(status.OK, response.status_code, "POST data is malformed") - - # response = self.csrf_client.post('/', content) - # self.assertEqual(status.OK, response.status_code, "POST data is malformed") - class TestUserSetter(TestCase): def setUp(self): # Pass request object through session middleware so session is # available to login and logout functions - self.request = Request(factory.get('/')) + self.wrapped_request = factory.get('/') + self.request = Request(self.wrapped_request) SessionMiddleware().process_request(self.request) User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow') @@ -339,6 +245,10 @@ class TestUserSetter(TestCase): logout(self.request) self.assertTrue(self.request.user.is_anonymous()) + def test_logged_in_user_is_set_on_wrapped_request(self): + login(self.request, self.user) + self.assertEqual(self.wrapped_request.user, self.user) + class TestAuthSetter(TestCase): diff --git a/tests/test_response.py b/tests/test_response.py index 2eff83d3..f233ae33 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -2,11 +2,12 @@ from __future__ import unicode_literals from django.conf.urls import patterns, url, include from django.test import TestCase from django.utils import six -from tests.models import BasicModel, BasicModelSerializer +from tests.models import BasicModel from rest_framework.response import Response from rest_framework.views import APIView from rest_framework import generics from rest_framework import routers +from rest_framework import serializers from rest_framework import status from rest_framework.renderers import ( BaseRenderer, @@ -17,6 +18,12 @@ from rest_framework import viewsets from rest_framework.settings import api_settings +# Serializer used to test BasicModel +class BasicModelSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + class MockPickleRenderer(BaseRenderer): media_type = 'application/pickle' @@ -86,14 +93,15 @@ class HTMLView1(APIView): class HTMLNewModelViewSet(viewsets.ModelViewSet): - model = BasicModel + serializer_class = BasicModelSerializer + queryset = BasicModel.objects.all() class HTMLNewModelView(generics.ListCreateAPIView): renderer_classes = (BrowsableAPIRenderer,) permission_classes = [] serializer_class = BasicModelSerializer - model = BasicModel + queryset = BasicModel.objects.all() new_model_viewset_router = routers.DefaultRouter() @@ -224,8 +232,8 @@ class Issue467Tests(TestCase): def test_form_has_label_and_help_text(self): resp = self.client.get('/html_new_model') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') - self.assertContains(resp, 'Text comes here') - self.assertContains(resp, 'Text description.') + # self.assertContains(resp, 'Text comes here') + # self.assertContains(resp, 'Text description.') class Issue807Tests(TestCase): @@ -254,9 +262,9 @@ class Issue807Tests(TestCase): expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset) self.assertEqual(expected, resp['Content-Type']) - def test_content_type_set_explictly_on_response(self): + def test_content_type_set_explicitly_on_response(self): """ - The content type may be set explictly on the response. + The content type may be set explicitly on the response. """ headers = {"HTTP_ACCEPT": RendererC.media_type} resp = self.client.get('/setbyview', **headers) @@ -269,11 +277,11 @@ class Issue807Tests(TestCase): ) resp = self.client.get('/html_new_model_viewset/' + param) self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') - self.assertContains(resp, 'Text comes here') - self.assertContains(resp, 'Text description.') + # self.assertContains(resp, 'Text comes here') + # self.assertContains(resp, 'Text description.') def test_form_has_label_and_help_text(self): resp = self.client.get('/html_new_model') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') - self.assertContains(resp, 'Text comes here') - self.assertContains(resp, 'Text description.') + # self.assertContains(resp, 'Text comes here') + # self.assertContains(resp, 'Text description.') diff --git a/tests/test_routers.py b/tests/test_routers.py index 73d10822..06ab8103 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -77,9 +77,10 @@ class TestCustomLookupFields(TestCase): def setUp(self): class NoteSerializer(serializers.HyperlinkedModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') + class Meta: model = RouterTestModel - lookup_field = 'uuid' fields = ('url', 'uuid', 'text') class NoteViewSet(viewsets.ModelViewSet): @@ -87,8 +88,6 @@ class TestCustomLookupFields(TestCase): serializer_class = NoteSerializer lookup_field = 'uuid' - RouterTestModel.objects.create(uuid='123', text='foo bar') - self.router = SimpleRouter() self.router.register(r'notes', NoteViewSet) @@ -99,6 +98,8 @@ class TestCustomLookupFields(TestCase): url(r'^', include(self.router.urls)), ) + RouterTestModel.objects.create(uuid='123', text='foo bar') + def test_custom_lookup_field_route(self): detail_route = self.router.urls[-1] detail_url_pattern = detail_route.regex.pattern diff --git a/tests/test_serializer.py b/tests/test_serializer.py index e72b723f..c17b6d8c 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,2006 +1,218 @@ -# -*- coding: utf-8 -*- +# coding: utf-8 from __future__ import unicode_literals -from django.db import models -from django.db.models.fields import BLANK_CHOICE_DASH -from django.test import TestCase -from django.utils import unittest -from django.utils.datastructures import MultiValueDict -from django.utils.translation import ugettext_lazy as _ -from rest_framework import serializers, fields, relations -from tests.models import ( - HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, - BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, - DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, - RESTFrameworkModel, ForeignKeySource -) -from tests.models import BasicModelSerializer -import datetime -import pickle -try: - import PIL -except: - PIL = None - - -if PIL is not None: - class AMOAFModel(RESTFrameworkModel): - char_field = models.CharField(max_length=1024, blank=True) - comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) - decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) - email_field = models.EmailField(max_length=1024, blank=True) - file_field = models.FileField(upload_to='test', max_length=1024, blank=True) - image_field = models.ImageField(upload_to='test', max_length=1024, blank=True) - slug_field = models.SlugField(max_length=1024, blank=True) - url_field = models.URLField(max_length=1024, blank=True) - nullable_char_field = models.CharField(max_length=1024, blank=True, null=True) - - class DVOAFModel(RESTFrameworkModel): - positive_integer_field = models.PositiveIntegerField(blank=True) - positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) - email_field = models.EmailField(blank=True) - file_field = models.FileField(upload_to='test', blank=True) - image_field = models.ImageField(upload_to='test', blank=True) - slug_field = models.SlugField(blank=True) - url_field = models.URLField(blank=True) - - -class SubComment(object): - def __init__(self, sub_comment): - self.sub_comment = sub_comment - - -class Comment(object): - def __init__(self, email, content, created): - self.email = email - self.content = content - self.created = created or datetime.datetime.now() - - def __eq__(self, other): - return all([getattr(self, attr) == getattr(other, attr) - for attr in ('email', 'content', 'created')]) - - def get_sub_comment(self): - sub_comment = SubComment('And Merry Christmas!') - return sub_comment - - -class CommentSerializer(serializers.Serializer): - email = serializers.EmailField() - content = serializers.CharField(max_length=1000) - created = serializers.DateTimeField() - sub_comment = serializers.Field(source='get_sub_comment.sub_comment') - - def restore_object(self, data, instance=None): - if instance is None: - return Comment(**data) - for key, val in data.items(): - setattr(instance, key, val) - return instance - - -class NamesSerializer(serializers.Serializer): - first = serializers.CharField() - last = serializers.CharField(required=False, default='') - initials = serializers.CharField(required=False, default='') - - -class PersonIdentifierSerializer(serializers.Serializer): - ssn = serializers.CharField() - names = NamesSerializer(source='names', required=False) - - -class BookSerializer(serializers.ModelSerializer): - isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'}) - - class Meta: - model = Book - - -class ActionItemSerializer(serializers.ModelSerializer): - - class Meta: - model = ActionItem - - -class ActionItemSerializerOptionalFields(serializers.ModelSerializer): - """ - Intended to test that fields with `required=False` are excluded from validation. - """ - title = serializers.CharField(required=False) - - class Meta: - model = ActionItem - fields = ('title',) - - -class ActionItemSerializerCustomRestore(serializers.ModelSerializer): - - class Meta: - model = ActionItem - - def restore_object(self, data, instance=None): - if instance is None: - return ActionItem(**data) - for key, val in data.items(): - setattr(instance, key, val) - return instance - - -class PersonSerializer(serializers.ModelSerializer): - info = serializers.Field(source='info') - - class Meta: - model = Person - fields = ('name', 'age', 'info') - read_only_fields = ('age',) - - -class NestedSerializer(serializers.Serializer): - info = serializers.Field() - - -class ModelSerializerWithNestedSerializer(serializers.ModelSerializer): - nested = NestedSerializer(source='*') - - class Meta: - model = Person - - -class NestedSerializerWithRenamedField(serializers.Serializer): - renamed_info = serializers.Field(source='info') - - -class ModelSerializerWithNestedSerializerWithRenamedField(serializers.ModelSerializer): - nested = NestedSerializerWithRenamedField(source='*') - - class Meta: - model = Person - - -class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): - """ - Testing for #652. - """ - info = serializers.Field(source='info') - - class Meta: - model = Person - fields = ('name', 'age', 'info') - read_only_fields = ('age', 'info') - - -class AlbumsSerializer(serializers.ModelSerializer): - - class Meta: - model = Album - fields = ['title', 'ref'] # lists are also valid options - - -class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): - class Meta: - model = HasPositiveIntegerAsChoice - fields = ['some_integer'] - - -class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeySource - - -class HyperlinkedForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): - class Meta: - model = ForeignKeySource - - -class BasicTests(TestCase): - def setUp(self): - self.comment = Comment( - 'tom@example.com', - 'Happy new year!', - datetime.datetime(2012, 1, 1) - ) - self.actionitem = ActionItem(title='Some to do item',) - self.data = { - 'email': 'tom@example.com', - 'content': 'Happy new year!', - 'created': datetime.datetime(2012, 1, 1), - 'sub_comment': 'This wont change' - } - self.expected = { - 'email': 'tom@example.com', - 'content': 'Happy new year!', - 'created': datetime.datetime(2012, 1, 1), - 'sub_comment': 'And Merry Christmas!' - } - self.person_data = {'name': 'dwight', 'age': 35} - self.person = Person(**self.person_data) - self.person.save() - - def test_empty(self): - serializer = CommentSerializer() - expected = { - 'email': '', - 'content': '', - 'created': None - } - self.assertEqual(serializer.data, expected) - - def test_retrieve(self): - serializer = CommentSerializer(self.comment) - self.assertEqual(serializer.data, self.expected) - - def test_create(self): - serializer = CommentSerializer(data=self.data) - expected = self.comment - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, expected) - self.assertFalse(serializer.object is expected) - self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') - - def test_create_nested(self): - """Test a serializer with nested data.""" - names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'} - data = {'ssn': '1234567890', 'names': names} - serializer = PersonIdentifierSerializer(data=data) - - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, data) - self.assertFalse(serializer.object is data) - self.assertEqual(serializer.data['names'], names) - - def test_create_partial_nested(self): - """Test a serializer with nested data which has missing fields.""" - names = {'first': 'John'} - data = {'ssn': '1234567890', 'names': names} - serializer = PersonIdentifierSerializer(data=data) - - expected_names = {'first': 'John', 'last': '', 'initials': ''} - data['names'] = expected_names - - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, data) - self.assertFalse(serializer.object is expected_names) - self.assertEqual(serializer.data['names'], expected_names) - - def test_null_nested(self): - """Test a serializer with a nonexistent nested field""" - data = {'ssn': '1234567890'} - serializer = PersonIdentifierSerializer(data=data) - - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, data) - self.assertFalse(serializer.object is data) - expected = {'ssn': '1234567890', 'names': None} - self.assertEqual(serializer.data, expected) - - def test_update(self): - serializer = CommentSerializer(self.comment, data=self.data) - expected = self.comment - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, expected) - self.assertTrue(serializer.object is expected) - self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') - - def test_partial_update(self): - msg = 'Merry New Year!' - partial_data = {'content': msg} - serializer = CommentSerializer(self.comment, data=partial_data) - self.assertEqual(serializer.is_valid(), False) - serializer = CommentSerializer(self.comment, data=partial_data, partial=True) - expected = self.comment - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, expected) - self.assertTrue(serializer.object is expected) - self.assertEqual(serializer.data['content'], msg) - - def test_model_fields_as_expected(self): - """ - Make sure that the fields returned are the same as defined - in the Meta data - """ - serializer = PersonSerializer(self.person) - self.assertEqual( - set(serializer.data.keys()), - set(['name', 'age', 'info']) - ) - - def test_field_with_dictionary(self): - """ - Make sure that dictionaries from fields are left intact - """ - serializer = PersonSerializer(self.person) - expected = self.person_data - self.assertEqual(serializer.data['info'], expected) - - def test_read_only_fields(self): - """ - Attempting to update fields set as read_only should have no effect. - """ - serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(serializer.errors, {}) - # Assert age is unchanged (35) - self.assertEqual(instance.age, self.person_data['age']) - - def test_invalid_read_only_fields(self): - """ - Regression test for #652. - """ - serializer = PersonSerializerInvalidReadOnly() - with self.assertRaises(AssertionError): - serializer.fields - - def test_serializer_data_is_cleared_on_save(self): - """ - Check _data attribute is cleared on `save()` - - Regression test for #1116 - — id field is not populated if `data` is accessed prior to `save()` - """ - serializer = ActionItemSerializer(self.actionitem) - self.assertIsNone(serializer.data.get('id', None), 'New instance. `id` should not be set.') - serializer.save() - self.assertIsNotNone(serializer.data.get('id', None), 'Model is saved. `id` should be set.') - - def test_fields_marked_as_not_required_are_excluded_from_validation(self): - """ - Check that fields with `required=False` are included in list of exclusions. - """ - serializer = ActionItemSerializerOptionalFields(self.actionitem) - exclusions = serializer.get_validation_exclusions() - self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded') - - -class DictStyleSerializer(serializers.Serializer): - """ - Note that we don't have any `restore_object` method, so the default - case of simply returning a dict will apply. - """ - email = serializers.EmailField() - - -class DictStyleSerializerTests(TestCase): - def test_dict_style_deserialize(self): - """ - Ensure serializers can deserialize into a dict. - """ - data = {'email': 'foo@example.com'} - serializer = DictStyleSerializer(data=data) - self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, data) - - def test_dict_style_serialize(self): - """ - Ensure serializers can serialize dict objects. - """ - data = {'email': 'foo@example.com'} - serializer = DictStyleSerializer(data) - self.assertEqual(serializer.data, data) - - -class ValidationTests(TestCase): - def setUp(self): - self.comment = Comment( - 'tom@example.com', - 'Happy new year!', - datetime.datetime(2012, 1, 1) - ) - self.data = { - 'email': 'tom@example.com', - 'content': 'x' * 1001, - 'created': datetime.datetime(2012, 1, 1) - } - self.actionitem = ActionItem(title='Some to do item',) - - def test_create(self): - serializer = CommentSerializer(data=self.data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) - - def test_update(self): - serializer = CommentSerializer(self.comment, data=self.data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']}) - - def test_update_missing_field(self): - data = { - 'content': 'xxx', - 'created': datetime.datetime(2012, 1, 1) - } - serializer = CommentSerializer(self.comment, data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'email': ['This field is required.']}) - - def test_missing_bool_with_default(self): - """Make sure that a boolean value with a 'False' value is not - mistaken for not having a default.""" - data = { - 'title': 'Some action item', - # No 'done' value. - } - serializer = ActionItemSerializer(self.actionitem, data=data) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.errors, {}) - - def test_cross_field_validation(self): - - class CommentSerializerWithCrossFieldValidator(CommentSerializer): +from rest_framework import serializers +from rest_framework.compat import unicode_repr +import pytest + + +# Tests for core functionality. +# ----------------------------- + +class TestSerializer: + def setup(self): + class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() + self.Serializer = ExampleSerializer + + def test_valid_serializer(self): + serializer = self.Serializer(data={'char': 'abc', 'integer': 123}) + assert serializer.is_valid() + assert serializer.validated_data == {'char': 'abc', 'integer': 123} + assert serializer.errors == {} + + def test_invalid_serializer(self): + serializer = self.Serializer(data={'char': 'abc'}) + assert not serializer.is_valid() + assert serializer.validated_data == {} + assert serializer.errors == {'integer': ['This field is required.']} + + def test_partial_validation(self): + serializer = self.Serializer(data={'char': 'abc'}, partial=True) + assert serializer.is_valid() + assert serializer.validated_data == {'char': 'abc'} + assert serializer.errors == {} + + def test_empty_serializer(self): + serializer = self.Serializer() + assert serializer.data == {'char': '', 'integer': None} + + def test_missing_attribute_during_serialization(self): + class MissingAttributes: + pass + instance = MissingAttributes() + serializer = self.Serializer(instance) + with pytest.raises(AttributeError): + serializer.data + + +class TestValidateMethod: + def test_non_field_error_validate_method(self): + class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() def validate(self, attrs): - if attrs["email"] not in attrs["content"]: - raise serializers.ValidationError("Email address not in content") - return attrs - - data = { - 'email': 'tom@example.com', - 'content': 'A comment from tom@example.com', - 'created': datetime.datetime(2012, 1, 1) - } - - serializer = CommentSerializerWithCrossFieldValidator(data=data) - self.assertTrue(serializer.is_valid()) - - data['content'] = 'A comment from foo@bar.com' - - serializer = CommentSerializerWithCrossFieldValidator(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']}) - - def test_null_is_true_fields(self): - """ - Omitting a value for null-field should validate. - """ - serializer = PersonSerializer(data={'name': 'marko'}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.errors, {}) - - def test_modelserializer_max_length_exceeded(self): - data = { - 'title': 'x' * 201, - } - serializer = ActionItemSerializer(data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) - - def test_modelserializer_max_length_exceeded_with_custom_restore(self): - """ - When overriding ModelSerializer.restore_object, validation tests should still apply. - Regression test for #623. - - https://github.com/tomchristie/django-rest-framework/pull/623 - """ - data = { - 'title': 'x' * 201, - } - serializer = ActionItemSerializerCustomRestore(data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) - - def test_default_modelfield_max_length_exceeded(self): - data = { - 'title': 'Testing "info" field...', - 'info': 'x' * 13, - } - serializer = ActionItemSerializer(data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']}) - - def test_datetime_validation_failure(self): - """ - Test DateTimeField validation errors on non-str values. - Regression test for #669. - - https://github.com/tomchristie/django-rest-framework/issues/669 - """ - data = self.data - data['created'] = 0 - - serializer = CommentSerializer(data=data) - self.assertEqual(serializer.is_valid(), False) - - self.assertIn('created', serializer.errors) - - def test_missing_model_field_exception_msg(self): - """ - Assert that a meaningful exception message is outputted when the model - field is missing (e.g. when mistyping ``model``). - """ - class BrokenModelSerializer(serializers.ModelSerializer): - class Meta: - fields = ['some_field'] - - try: - BrokenModelSerializer() - except AssertionError as e: - self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option") - except: - self.fail('Wrong exception type thrown.') - - def test_writable_star_source_on_nested_serializer(self): - """ - Assert that a nested serializer instantiated with source='*' correctly - expands the data into the outer serializer. - """ - serializer = ModelSerializerWithNestedSerializer(data={ - 'name': 'marko', - 'nested': {'info': 'hi'}}, - ) - self.assertEqual(serializer.is_valid(), True) - - def test_writable_star_source_on_nested_serializer_with_parent_object(self): - class TitleSerializer(serializers.Serializer): - title = serializers.WritableField(source='title') - - class AlbumSerializer(serializers.ModelSerializer): - nested = TitleSerializer(source='*') - - class Meta: - model = Album - fields = ('nested',) - - class PhotoSerializer(serializers.ModelSerializer): - album = AlbumSerializer(source='album') - - class Meta: - model = Photo - fields = ('album', ) - - photo = Photo(album=Album()) - - data = {'album': {'nested': {'title': 'test'}}} - - serializer = PhotoSerializer(photo, data=data) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.data, data) - - def test_writable_star_source_with_inner_source_fields(self): - """ - Tests that a serializer with source="*" correctly expands the - it's fields into the outer serializer even if they have their - own 'source' parameters. - """ - - serializer = ModelSerializerWithNestedSerializerWithRenamedField(data={ - 'name': 'marko', - 'nested': {'renamed_info': 'hi'}}, - ) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.errors, {}) - - -class CustomValidationTests(TestCase): - class CommentSerializerWithFieldValidator(CommentSerializer): - - def validate_email(self, attrs, source): - attrs[source] - return attrs - - def validate_content(self, attrs, source): - value = attrs[source] - if "test" not in value: - raise serializers.ValidationError("Test not in value") - return attrs - - def test_field_validation(self): - data = { - 'email': 'tom@example.com', - 'content': 'A test comment', - 'created': datetime.datetime(2012, 1, 1) - } - - serializer = self.CommentSerializerWithFieldValidator(data=data) - self.assertTrue(serializer.is_valid()) - - data['content'] = 'This should not validate' - - serializer = self.CommentSerializerWithFieldValidator(data=data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'content': ['Test not in value']}) - - def test_missing_data(self): - """ - Make sure that validate_content isn't called if the field is missing - """ - incomplete_data = { - 'email': 'tom@example.com', - 'created': datetime.datetime(2012, 1, 1) - } - serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'content': ['This field is required.']}) - - def test_wrong_data(self): - """ - Make sure that validate_content isn't called if the field input is wrong - """ - wrong_data = { - 'email': 'not an email', - 'content': 'A test comment', - 'created': datetime.datetime(2012, 1, 1) - } - serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']}) - - def test_partial_update(self): - """ - Make sure that validate_email isn't called when partial=True and email - isn't found in data. - """ - initial_data = { - 'email': 'tom@example.com', - 'content': 'A test comment', - 'created': datetime.datetime(2012, 1, 1) - } - - serializer = self.CommentSerializerWithFieldValidator(data=initial_data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.object - - new_content = 'An *updated* test comment' - partial_data = { - 'content': new_content - } - - serializer = self.CommentSerializerWithFieldValidator(instance=instance, - data=partial_data, - partial=True) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.object - self.assertEqual(instance.content, new_content) - - -class PositiveIntegerAsChoiceTests(TestCase): - def test_positive_integer_in_json_is_correctly_parsed(self): - data = {'some_integer': 1} - serializer = PositiveIntegerAsChoiceSerializer(data=data) - self.assertEqual(serializer.is_valid(), True) - - -class ModelValidationTests(TestCase): - def test_validate_unique(self): - """ - Just check if serializers.ModelSerializer handles unique checks via .full_clean() - """ - serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'}) - serializer.is_valid() - serializer.save() - second_serializer = AlbumsSerializer(data={'title': 'a'}) - self.assertFalse(second_serializer.is_valid()) - self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']}) - third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}], many=True) - self.assertFalse(third_serializer.is_valid()) - self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}]) - - def test_foreign_key_is_null_with_partial(self): - """ - Test ModelSerializer validation with partial=True - - Specifically test that a null foreign key does not pass validation - """ - album = Album(title='test') - album.save() - - class PhotoSerializer(serializers.ModelSerializer): - class Meta: - model = Photo - - photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) - self.assertTrue(photo_serializer.is_valid()) - photo = photo_serializer.save() - - # Updating only the album (foreign key) - photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True) - self.assertFalse(photo_serializer.is_valid()) - self.assertTrue('album' in photo_serializer.errors) - self.assertEqual(photo_serializer.errors['album'], [photo_serializer.error_messages['required']]) - - def test_foreign_key_with_partial(self): - """ - Test ModelSerializer validation with partial=True - - Specifically test foreign key validation. - """ - - album = Album(title='test') - album.save() - - class PhotoSerializer(serializers.ModelSerializer): - class Meta: - model = Photo - - photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) - self.assertTrue(photo_serializer.is_valid()) - photo = photo_serializer.save() - - # Updating only the album (foreign key) - photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True) - self.assertTrue(photo_serializer.is_valid()) - self.assertTrue(photo_serializer.save()) - - # Updating only the description - photo_serializer = PhotoSerializer(instance=photo, - data={'description': 'new'}, - partial=True) - - self.assertTrue(photo_serializer.is_valid()) - self.assertTrue(photo_serializer.save()) - - -class RegexValidationTest(TestCase): - def test_create_failed(self): - serializer = BookSerializer(data={'isbn': '1234567890'}) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) - - serializer = BookSerializer(data={'isbn': '12345678901234'}) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) - - serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) - self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']}) - - def test_create_success(self): - serializer = BookSerializer(data={'isbn': '1234567890123'}) - self.assertTrue(serializer.is_valid()) - - -class MetadataTests(TestCase): - def test_empty(self): - serializer = CommentSerializer() - expected = { - 'email': serializers.CharField, - 'content': serializers.CharField, - 'created': serializers.DateTimeField - } - for field_name, field in expected.items(): - self.assertTrue(isinstance(serializer.data.fields[field_name], field)) - - -class ManyToManyTests(TestCase): - def setUp(self): - class ManyToManySerializer(serializers.ModelSerializer): - class Meta: - model = ManyToManyModel - - self.serializer_class = ManyToManySerializer - - # An anchor instance to use for the relationship - self.anchor = Anchor() - self.anchor.save() - - # A model instance with a many to many relationship to the anchor - self.instance = ManyToManyModel() - self.instance.save() - self.instance.rel.add(self.anchor) - - # A serialized representation of the model instance - self.data = {'id': 1, 'rel': [self.anchor.id]} - - def test_retrieve(self): - """ - Serialize an instance of a model with a ManyToMany relationship. - """ - serializer = self.serializer_class(instance=self.instance) - expected = self.data - self.assertEqual(serializer.data, expected) - - def test_create(self): - """ - Create an instance of a model with a ManyToMany relationship. - """ - data = {'rel': [self.anchor.id]} - serializer = self.serializer_class(data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(ManyToManyModel.objects.all()), 2) - self.assertEqual(instance.pk, 2) - self.assertEqual(list(instance.rel.all()), [self.anchor]) - - def test_update(self): - """ - Update an instance of a model with a ManyToMany relationship. - """ - new_anchor = Anchor() - new_anchor.save() - data = {'rel': [self.anchor.id, new_anchor.id]} - serializer = self.serializer_class(self.instance, data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(ManyToManyModel.objects.all()), 1) - self.assertEqual(instance.pk, 1) - self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor]) - - def test_create_empty_relationship(self): - """ - Create an instance of a model with a ManyToMany relationship, - containing no items. - """ - data = {'rel': []} - serializer = self.serializer_class(data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(ManyToManyModel.objects.all()), 2) - self.assertEqual(instance.pk, 2) - self.assertEqual(list(instance.rel.all()), []) - - def test_update_empty_relationship(self): - """ - Update an instance of a model with a ManyToMany relationship, - containing no items. - """ - new_anchor = Anchor() - new_anchor.save() - data = {'rel': []} - serializer = self.serializer_class(self.instance, data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(ManyToManyModel.objects.all()), 1) - self.assertEqual(instance.pk, 1) - self.assertEqual(list(instance.rel.all()), []) - - def test_create_empty_relationship_flat_data(self): - """ - Create an instance of a model with a ManyToMany relationship, - containing no items, using a representation that does not support - lists (eg form data). - """ - data = MultiValueDict() - data.setlist('rel', ['']) - serializer = self.serializer_class(data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(ManyToManyModel.objects.all()), 2) - self.assertEqual(instance.pk, 2) - self.assertEqual(list(instance.rel.all()), []) - - -class ReadOnlyManyToManyTests(TestCase): - def setUp(self): - class ReadOnlyManyToManySerializer(serializers.ModelSerializer): - rel = serializers.RelatedField(many=True, read_only=True) - - class Meta: - model = ReadOnlyManyToManyModel - - self.serializer_class = ReadOnlyManyToManySerializer - - # An anchor instance to use for the relationship - self.anchor = Anchor() - self.anchor.save() - - # A model instance with a many to many relationship to the anchor - self.instance = ReadOnlyManyToManyModel() - self.instance.save() - self.instance.rel.add(self.anchor) - - # A serialized representation of the model instance - self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'} - - def test_update(self): - """ - Attempt to update an instance of a model with a ManyToMany - relationship. Not updated due to read_only=True - """ - new_anchor = Anchor() - new_anchor.save() - data = {'rel': [self.anchor.id, new_anchor.id]} - serializer = self.serializer_class(self.instance, data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) - self.assertEqual(instance.pk, 1) - # rel is still as original (1 entry) - self.assertEqual(list(instance.rel.all()), [self.anchor]) - - def test_update_without_relationship(self): - """ - Attempt to update an instance of a model where many to ManyToMany - relationship is not supplied. Not updated due to read_only=True - """ - new_anchor = Anchor() - new_anchor.save() - data = {} - serializer = self.serializer_class(self.instance, data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1) - self.assertEqual(instance.pk, 1) - # rel is still as original (1 entry) - self.assertEqual(list(instance.rel.all()), [self.anchor]) - - -class DefaultValueTests(TestCase): - def setUp(self): - class DefaultValueSerializer(serializers.ModelSerializer): - class Meta: - model = DefaultValueModel - - self.serializer_class = DefaultValueSerializer - self.objects = DefaultValueModel.objects - - def test_create_using_default(self): - data = {} - serializer = self.serializer_class(data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(self.objects.all()), 1) - self.assertEqual(instance.pk, 1) - self.assertEqual(instance.text, 'foobar') - - def test_create_overriding_default(self): - data = {'text': 'overridden'} - serializer = self.serializer_class(data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(self.objects.all()), 1) - self.assertEqual(instance.pk, 1) - self.assertEqual(instance.text, 'overridden') - - def test_partial_update_default(self): - """ Regression test for issue #532 """ - data = {'text': 'overridden'} - serializer = self.serializer_class(data=data, partial=True) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - - data = {'extra': 'extra_value'} - serializer = self.serializer_class(instance=instance, data=data, partial=True) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - - self.assertEqual(instance.extra, 'extra_value') - self.assertEqual(instance.text, 'overridden') - - -class WritableFieldDefaultValueTests(TestCase): - - def setUp(self): - self.expected = {'default': 'value'} - self.create_field = fields.WritableField + raise serializers.ValidationError('Non field error') - def test_get_default_value_with_noncallable(self): - field = self.create_field(default=self.expected) - got = field.get_default_value() - self.assertEqual(got, self.expected) + serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) + assert not serializer.is_valid() + assert serializer.errors == {'non_field_errors': ['Non field error']} - def test_get_default_value_with_callable(self): - field = self.create_field(default=lambda: self.expected) - got = field.get_default_value() - self.assertEqual(got, self.expected) + def test_field_error_validate_method(self): + class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() - def test_get_default_value_when_not_required(self): - field = self.create_field(default=self.expected, required=False) - got = field.get_default_value() - self.assertEqual(got, self.expected) - - def test_get_default_value_returns_None(self): - field = self.create_field() - got = field.get_default_value() - self.assertIsNone(got) - - def test_get_default_value_returns_non_True_values(self): - values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause - for expected in values: - field = self.create_field(default=expected) - got = field.get_default_value() - self.assertEqual(got, expected) - - -class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): - - def setUp(self): - self.expected = {'foo': 'bar'} - self.create_field = relations.RelatedField - - def test_get_default_value_returns_empty_list(self): - field = self.create_field(many=True) - got = field.get_default_value() - self.assertListEqual(got, []) - - def test_get_default_value_returns_expected(self): - expected = [1, 2, 3] - field = self.create_field(many=True, default=expected) - got = field.get_default_value() - self.assertListEqual(got, expected) - - -class CallableDefaultValueTests(TestCase): - def setUp(self): - class CallableDefaultValueSerializer(serializers.ModelSerializer): - class Meta: - model = CallableDefaultValueModel - - self.serializer_class = CallableDefaultValueSerializer - self.objects = CallableDefaultValueModel.objects - - def test_create_using_default(self): - data = {} - serializer = self.serializer_class(data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(self.objects.all()), 1) - self.assertEqual(instance.pk, 1) - self.assertEqual(instance.text, 'foobar') - - def test_create_overriding_default(self): - data = {'text': 'overridden'} - serializer = self.serializer_class(data=data) - self.assertEqual(serializer.is_valid(), True) - instance = serializer.save() - self.assertEqual(len(self.objects.all()), 1) - self.assertEqual(instance.pk, 1) - self.assertEqual(instance.text, 'overridden') - - -class ManyRelatedTests(TestCase): - def test_reverse_relations(self): - post = BlogPost.objects.create(title="Test blog post") - post.blogpostcomment_set.create(text="I hate this blog post") - post.blogpostcomment_set.create(text="I love this blog post") - - class BlogPostCommentSerializer(serializers.Serializer): - text = serializers.CharField() - - class BlogPostSerializer(serializers.Serializer): - title = serializers.CharField() - comments = BlogPostCommentSerializer(source='blogpostcomment_set') - - serializer = BlogPostSerializer(instance=post) - expected = { - 'title': 'Test blog post', - 'comments': [ - {'text': 'I hate this blog post'}, - {'text': 'I love this blog post'} - ] - } - - self.assertEqual(serializer.data, expected) - - def test_include_reverse_relations(self): - post = BlogPost.objects.create(title="Test blog post") - post.blogpostcomment_set.create(text="I hate this blog post") - post.blogpostcomment_set.create(text="I love this blog post") - - class BlogPostSerializer(serializers.ModelSerializer): - class Meta: - model = BlogPost - fields = ('id', 'title', 'blogpostcomment_set') - - serializer = BlogPostSerializer(instance=post) - expected = { - 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] - } - self.assertEqual(serializer.data, expected) - - def test_depth_include_reverse_relations(self): - post = BlogPost.objects.create(title="Test blog post") - post.blogpostcomment_set.create(text="I hate this blog post") - post.blogpostcomment_set.create(text="I love this blog post") - - class BlogPostSerializer(serializers.ModelSerializer): - class Meta: - model = BlogPost - fields = ('id', 'title', 'blogpostcomment_set') - depth = 1 - - serializer = BlogPostSerializer(instance=post) - expected = { - 'id': 1, 'title': 'Test blog post', - 'blogpostcomment_set': [ - {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, - {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} - ] - } - self.assertEqual(serializer.data, expected) - - def test_callable_source(self): - post = BlogPost.objects.create(title="Test blog post") - post.blogpostcomment_set.create(text="I love this blog post") - - class BlogPostCommentSerializer(serializers.Serializer): - text = serializers.CharField() - - class BlogPostSerializer(serializers.Serializer): - title = serializers.CharField() - first_comment = BlogPostCommentSerializer(source='get_first_comment') - - serializer = BlogPostSerializer(post) - - expected = { - 'title': 'Test blog post', - 'first_comment': {'text': 'I love this blog post'} - } - self.assertEqual(serializer.data, expected) - - -class RelatedTraversalTest(TestCase): - def test_nested_traversal(self): - """ - Source argument should support dotted.source notation. - """ - user = Person.objects.create(name="django") - post = BlogPost.objects.create(title="Test blog post", writer=user) - post.blogpostcomment_set.create(text="I love this blog post") - - class PersonSerializer(serializers.ModelSerializer): - class Meta: - model = Person - fields = ("name", "age") - - class BlogPostCommentSerializer(serializers.ModelSerializer): - class Meta: - model = BlogPostComment - fields = ("text", "post_owner") - - text = serializers.CharField() - post_owner = PersonSerializer(source='blog_post.writer') + def validate(self, attrs): + raise serializers.ValidationError({'char': 'Field error'}) - class BlogPostSerializer(serializers.Serializer): - title = serializers.CharField() - comments = BlogPostCommentSerializer(source='blogpostcomment_set') + serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) + assert not serializer.is_valid() + assert serializer.errors == {'char': ['Field error']} - serializer = BlogPostSerializer(instance=post) - expected = { - 'title': 'Test blog post', - 'comments': [{ - 'text': 'I love this blog post', - 'post_owner': { - "name": "django", - "age": None +class TestBaseSerializer: + def setup(self): + class ExampleSerializer(serializers.BaseSerializer): + def to_representation(self, obj): + return { + 'id': obj['id'], + 'email': obj['name'] + '@' + obj['domain'] } - }] - } - - self.assertEqual(serializer.data, expected) - - def test_nested_traversal_with_none(self): - """ - If a component of the dotted.source is None, return None for the field. - """ - from tests.models import NullableForeignKeySource - instance = NullableForeignKeySource.objects.create(name='Source with null FK') - - class NullableSourceSerializer(serializers.Serializer): - target_name = serializers.Field(source='target.name') - - serializer = NullableSourceSerializer(instance=instance) - - expected = { - 'target_name': None, - } - - self.assertEqual(serializer.data, expected) - - -class SerializerMethodFieldTests(TestCase): - def setUp(self): - - class BoopSerializer(serializers.Serializer): - beep = serializers.SerializerMethodField('get_beep') - boop = serializers.Field() - boop_count = serializers.SerializerMethodField('get_boop_count') - - def get_beep(self, obj): - return 'hello!' - - def get_boop_count(self, obj): - return len(obj.boop) - - self.serializer_class = BoopSerializer - - def test_serializer_method_field(self): - - class MyModel(object): - boop = ['a', 'b', 'c'] - - source_data = MyModel() - - serializer = self.serializer_class(source_data) - - expected = { - 'beep': 'hello!', - 'boop': ['a', 'b', 'c'], - 'boop_count': 3, - } - - self.assertEqual(serializer.data, expected) - - -# Test for issue #324 -class BlankFieldTests(TestCase): - def setUp(self): - - class BlankFieldModelSerializer(serializers.ModelSerializer): - class Meta: - model = BlankFieldModel - - class BlankFieldSerializer(serializers.Serializer): - title = serializers.CharField(required=False) - - class NotBlankFieldModelSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel - - class NotBlankFieldSerializer(serializers.Serializer): - title = serializers.CharField() - - self.model_serializer_class = BlankFieldModelSerializer - self.serializer_class = BlankFieldSerializer - self.not_blank_model_serializer_class = NotBlankFieldModelSerializer - self.not_blank_serializer_class = NotBlankFieldSerializer - self.data = {'title': ''} - - def test_create_blank_field(self): - serializer = self.serializer_class(data=self.data) - self.assertEqual(serializer.is_valid(), True) - - def test_create_model_blank_field(self): - serializer = self.model_serializer_class(data=self.data) - self.assertEqual(serializer.is_valid(), True) - - def test_create_model_null_field(self): - serializer = self.model_serializer_class(data={'title': None}) - self.assertEqual(serializer.is_valid(), True) - serializer.save() - self.assertIsNot(serializer.object.pk, None) - self.assertEqual(serializer.object.title, '') - - def test_create_not_blank_field(self): - """ - Test to ensure blank data in a field not marked as blank=True - is considered invalid in a non-model serializer - """ - serializer = self.not_blank_serializer_class(data=self.data) - self.assertEqual(serializer.is_valid(), False) - - def test_create_model_not_blank_field(self): - """ - Test to ensure blank data in a field not marked as blank=True - is considered invalid in a model serializer - """ - serializer = self.not_blank_model_serializer_class(data=self.data) - self.assertEqual(serializer.is_valid(), False) - - def test_create_model_empty_field(self): - serializer = self.model_serializer_class(data={}) - self.assertEqual(serializer.is_valid(), True) - - def test_create_model_null_field_save(self): - """ - Regression test for #1330. - - https://github.com/tomchristie/django-rest-framework/pull/1330 - """ - serializer = self.model_serializer_class(data={'title': None}) - self.assertEqual(serializer.is_valid(), True) - - try: - serializer.save() - except Exception: - self.fail('Exception raised on save() after validation passes') + def to_internal_value(self, data): + name, domain = str(data['email']).split('@') + return { + 'id': int(data['id']), + 'name': name, + 'domain': domain, + } -# Test for issue #460 -class SerializerPickleTests(TestCase): + self.Serializer = ExampleSerializer + + def test_serialize_instance(self): + instance = {'id': 1, 'name': 'tom', 'domain': 'example.com'} + serializer = self.Serializer(instance) + assert serializer.data == {'id': 1, 'email': 'tom@example.com'} + + def test_serialize_list(self): + instances = [ + {'id': 1, 'name': 'tom', 'domain': 'example.com'}, + {'id': 2, 'name': 'ann', 'domain': 'example.com'}, + ] + serializer = self.Serializer(instances, many=True) + assert serializer.data == [ + {'id': 1, 'email': 'tom@example.com'}, + {'id': 2, 'email': 'ann@example.com'} + ] + + def test_validate_data(self): + data = {'id': 1, 'email': 'tom@example.com'} + serializer = self.Serializer(data=data) + assert serializer.is_valid() + assert serializer.validated_data == { + 'id': 1, + 'name': 'tom', + 'domain': 'example.com' + } + + def test_validate_list(self): + data = [ + {'id': 1, 'email': 'tom@example.com'}, + {'id': 2, 'email': 'ann@example.com'}, + ] + serializer = self.Serializer(data=data, many=True) + assert serializer.is_valid() + assert serializer.validated_data == [ + {'id': 1, 'name': 'tom', 'domain': 'example.com'}, + {'id': 2, 'name': 'ann', 'domain': 'example.com'} + ] + + +class TestStarredSource: """ - Test pickleability of the output of Serializers - """ - def test_pickle_simple_model_serializer_data(self): - """ - Test simple serializer - """ - pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data) - - def test_pickle_inner_serializer(self): - """ - Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will - have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle. - See DictWithMetadata.__getstate__ - """ - class InnerPersonSerializer(serializers.ModelSerializer): - class Meta: - model = Person - fields = ('name', 'age') - pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0) - - def test_getstate_method_should_not_return_none(self): - """ - Regression test for #645. - """ - data = serializers.DictWithMetadata({1: 1}) - self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1})) - - def test_serializer_data_is_pickleable(self): - """ - Another regression test for #645. - """ - data = serializers.SortedDictWithMetadata({1: 1}) - repr(pickle.loads(pickle.dumps(data, 0))) - - -# test for issue #725 -class SeveralChoicesModel(models.Model): - color = models.CharField( - max_length=10, - choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], - blank=False - ) - drink = models.CharField( - max_length=10, - choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], - blank=False, - default='beer' - ) - os = models.CharField( - max_length=10, - choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], - blank=True - ) - music_genre = models.CharField( - max_length=10, - choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], - blank=True, - default='metal' - ) - - -class SerializerChoiceFields(TestCase): - - def setUp(self): - super(SerializerChoiceFields, self).setUp() - - class SeveralChoicesSerializer(serializers.ModelSerializer): - class Meta: - model = SeveralChoicesModel - fields = ('color', 'drink', 'os', 'music_genre') - - self.several_choices_serializer = SeveralChoicesSerializer - - def test_choices_blank_false_not_default(self): - serializer = self.several_choices_serializer() - self.assertEqual( - serializer.fields['color'].choices, - [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')] - ) - - def test_choices_blank_false_with_default(self): - serializer = self.several_choices_serializer() - self.assertEqual( - serializer.fields['drink'].choices, - [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')] - ) - - def test_choices_blank_true_not_default(self): - serializer = self.several_choices_serializer() - self.assertEqual( - serializer.fields['os'].choices, - BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')] - ) - - def test_choices_blank_true_with_default(self): - serializer = self.several_choices_serializer() - self.assertEqual( - serializer.fields['music_genre'].choices, - BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')] - ) - - -# Regression tests for #675 -class Ticket(models.Model): - assigned = models.ForeignKey( - Person, related_name='assigned_tickets') - reviewer = models.ForeignKey( - Person, blank=True, null=True, related_name='reviewed_tickets') - - -class SerializerRelatedChoicesTest(TestCase): - - def setUp(self): - super(SerializerRelatedChoicesTest, self).setUp() - - class RelatedChoicesSerializer(serializers.ModelSerializer): - class Meta: - model = Ticket - fields = ('assigned', 'reviewer') - - self.related_fields_serializer = RelatedChoicesSerializer - - def test_empty_queryset_required(self): - serializer = self.related_fields_serializer() - self.assertEqual(serializer.fields['assigned'].queryset.count(), 0) - self.assertEqual( - [x for x in serializer.fields['assigned'].widget.choices], - [] - ) - - def test_empty_queryset_not_required(self): - serializer = self.related_fields_serializer() - self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0) - self.assertEqual( - [x for x in serializer.fields['reviewer'].widget.choices], - [('', '---------')] - ) - - def test_with_some_persons_required(self): - Person.objects.create(name="Lionel Messi") - Person.objects.create(name="Xavi Hernandez") - serializer = self.related_fields_serializer() - self.assertEqual(serializer.fields['assigned'].queryset.count(), 2) - self.assertEqual( - [x for x in serializer.fields['assigned'].widget.choices], - [(1, 'Person object - 1'), (2, 'Person object - 2')] - ) - - def test_with_some_persons_not_required(self): - Person.objects.create(name="Lionel Messi") - Person.objects.create(name="Xavi Hernandez") - serializer = self.related_fields_serializer() - self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2) - self.assertEqual( - [x for x in serializer.fields['reviewer'].widget.choices], - [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')] - ) - - -class DepthTest(TestCase): - def test_implicit_nesting(self): - - writer = Person.objects.create(name="django", age=1) - post = BlogPost.objects.create(title="Test blog post", writer=writer) - comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) - - class BlogPostCommentSerializer(serializers.ModelSerializer): - class Meta: - model = BlogPostComment - depth = 2 - - serializer = BlogPostCommentSerializer(instance=comment) - expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', - 'writer': {'id': 1, 'name': 'django', 'age': 1}}} + Tests for `source='*'` argument, which is used for nested representations. - self.assertEqual(serializer.data, expected) + For example: - def test_explicit_nesting(self): - writer = Person.objects.create(name="django", age=1) - post = BlogPost.objects.create(title="Test blog post", writer=writer) - comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post) - - class PersonSerializer(serializers.ModelSerializer): - class Meta: - model = Person - - class BlogPostSerializer(serializers.ModelSerializer): - writer = PersonSerializer() - - class Meta: - model = BlogPost - - class BlogPostCommentSerializer(serializers.ModelSerializer): - blog_post = BlogPostSerializer() - - class Meta: - model = BlogPostComment + nested_field = NestedField(source='*') + """ + data = { + 'nested1': {'a': 1, 'b': 2}, + 'nested2': {'c': 3, 'd': 4} + } - serializer = BlogPostCommentSerializer(instance=comment) - expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post', - 'writer': {'id': 1, 'name': 'django', 'age': 1}}} + def setup(self): + class NestedSerializer1(serializers.Serializer): + a = serializers.IntegerField() + b = serializers.IntegerField() - self.assertEqual(serializer.data, expected) + class NestedSerializer2(serializers.Serializer): + c = serializers.IntegerField() + d = serializers.IntegerField() + class TestSerializer(serializers.Serializer): + nested1 = NestedSerializer1(source='*') + nested2 = NestedSerializer2(source='*') -class NestedSerializerContextTests(TestCase): + self.Serializer = TestSerializer - def test_nested_serializer_context(self): + def test_nested_validate(self): """ - Regression for #497 - - https://github.com/tomchristie/django-rest-framework/issues/497 + A nested representation is validated into a flat internal object. """ - class PhotoSerializer(serializers.ModelSerializer): - class Meta: - model = Photo - fields = ("description", "callable") - - callable = serializers.SerializerMethodField('_callable') - - def _callable(self, instance): - if 'context_item' not in self.context: - raise RuntimeError("context isn't getting passed into 2nd level nested serializer") - return "success" - - class AlbumSerializer(serializers.ModelSerializer): - class Meta: - model = Album - fields = ("photo_set", "callable") - - photo_set = PhotoSerializer(source="photo_set", many=True) - callable = serializers.SerializerMethodField("_callable") - - def _callable(self, instance): - if 'context_item' not in self.context: - raise RuntimeError("context isn't getting passed into 1st level nested serializer") - return "success" - - class AlbumCollection(object): - albums = None - - class AlbumCollectionSerializer(serializers.Serializer): - albums = AlbumSerializer(source="albums", many=True) - - album1 = Album.objects.create(title="album 1") - album2 = Album.objects.create(title="album 2") - Photo.objects.create(description="Bigfoot", album=album1) - Photo.objects.create(description="Unicorn", album=album1) - Photo.objects.create(description="Yeti", album=album2) - Photo.objects.create(description="Sasquatch", album=album2) - album_collection = AlbumCollection() - album_collection.albums = [album1, album2] - - # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers - AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data - - -class DeserializeListTestCase(TestCase): - - def setUp(self): - self.data = { - 'email': 'nobody@nowhere.com', - 'content': 'This is some test content', - 'created': datetime.datetime(2013, 3, 7), + serializer = self.Serializer(data=self.data) + assert serializer.is_valid() + assert serializer.validated_data == { + 'a': 1, + 'b': 2, + 'c': 3, + 'd': 4 } - def test_no_errors(self): - data = [self.data.copy() for x in range(0, 3)] - serializer = CommentSerializer(data=data, many=True) - self.assertTrue(serializer.is_valid()) - self.assertTrue(isinstance(serializer.object, list)) - self.assertTrue( - all((isinstance(item, Comment) for item in serializer.object)) - ) - - def test_errors_return_as_list(self): - invalid_item = self.data.copy() - invalid_item['email'] = '' - data = [self.data.copy(), invalid_item, self.data.copy()] - - serializer = CommentSerializer(data=data, many=True) - self.assertFalse(serializer.is_valid()) - expected = [{}, {'email': ['This field is required.']}, {}] - self.assertEqual(serializer.errors, expected) - - -# Test for issue 747 - -class LazyStringModel(object): - def __init__(self, lazystring): - self.lazystring = lazystring - - -class LazyStringSerializer(serializers.Serializer): - lazystring = serializers.Field() - - def restore_object(self, attrs, instance=None): - if instance is not None: - instance.lazystring = attrs.get('lazystring', instance.lazystring) - return instance - return LazyStringModel(**attrs) - - -class LazyStringsTestCase(TestCase): - def setUp(self): - self.model = LazyStringModel(lazystring=_('lazystring')) - - def test_lazy_strings_are_translated(self): - serializer = LazyStringSerializer(self.model) - self.assertEqual(type(serializer.data['lazystring']), - type('lazystring')) - - -# Test for issue #467 - -class FieldLabelTest(TestCase): - def setUp(self): - self.serializer_class = BasicModelSerializer - - def test_label_from_model(self): - """ - Validates that label and help_text are correctly copied from the model class. - """ - serializer = self.serializer_class() - text_field = serializer.fields['text'] - - self.assertEqual('Text comes here', text_field.label) - self.assertEqual('Text description.', text_field.help_text) - - def test_field_ctor(self): + def test_nested_serialize(self): """ - This is check that ctor supports both label and help_text. + An object can be serialized into a nested representation. """ - self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label) - self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text) - self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) + instance = {'a': 1, 'b': 2, 'c': 3, 'd': 4} + serializer = self.Serializer(instance) + assert serializer.data == self.data -# Test for issue #961 - -class ManyFieldHelpTextTest(TestCase): - def test_help_text_no_hold_down_control_msg(self): - """ - Validate that help_text doesn't contain the 'Hold down "Control" ...' - message that Django appends to choice fields. - """ - rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text) - self.assertEqual('Some help text.', rel_field.help_text) - +class TestIncorrectlyConfigured: + def test_incorrect_field_name(self): + class ExampleSerializer(serializers.Serializer): + incorrect_name = serializers.IntegerField() -class AttributeMappingOnAutogeneratedRelatedFields(TestCase): - - def test_primary_key_related_field(self): - serializer = ForeignKeySourceSerializer() - self.assertEqual(serializer.fields['target'].help_text, 'Target') - self.assertEqual(serializer.fields['target'].label, 'Target') - - def test_hyperlinked_related_field(self): - serializer = HyperlinkedForeignKeySourceSerializer() - self.assertEqual(serializer.fields['target'].help_text, 'Target') - self.assertEqual(serializer.fields['target'].label, 'Target') - - -@unittest.skipUnless(PIL is not None, 'PIL is not installed') -class AttributeMappingOnAutogeneratedFieldsTests(TestCase): - - def setUp(self): - - class AMOAFSerializer(serializers.ModelSerializer): - class Meta: - model = AMOAFModel - - self.serializer_class = AMOAFSerializer - self.fields_attributes = { - 'char_field': [ - ('max_length', 1024), - ], - 'comma_separated_integer_field': [ - ('max_length', 1024), - ], - 'decimal_field': [ - ('max_digits', 64), - ('decimal_places', 32), - ], - 'email_field': [ - ('max_length', 1024), - ], - 'file_field': [ - ('max_length', 1024), - ], - 'image_field': [ - ('max_length', 1024), - ], - 'slug_field': [ - ('max_length', 1024), - ], - 'url_field': [ - ('max_length', 1024), - ], - 'nullable_char_field': [ - ('max_length', 1024), - ('allow_none', True), - ], - } + class ExampleObject: + def __init__(self): + self.correct_name = 123 - def field_test(self, field): - serializer = self.serializer_class(data={}) - self.assertEqual(serializer.is_valid(), True) - - for attribute in self.fields_attributes[field]: - self.assertEqual( - getattr(serializer.fields[field], attribute[0]), - attribute[1] - ) - - def test_char_field(self): - self.field_test('char_field') - - def test_comma_separated_integer_field(self): - self.field_test('comma_separated_integer_field') - - def test_decimal_field(self): - self.field_test('decimal_field') - - def test_email_field(self): - self.field_test('email_field') - - def test_file_field(self): - self.field_test('file_field') - - def test_image_field(self): - self.field_test('image_field') - - def test_slug_field(self): - self.field_test('slug_field') - - def test_url_field(self): - self.field_test('url_field') - - def test_nullable_char_field(self): - self.field_test('nullable_char_field') - - -@unittest.skipUnless(PIL is not None, 'PIL is not installed') -class DefaultValuesOnAutogeneratedFieldsTests(TestCase): - - def setUp(self): - - class DVOAFSerializer(serializers.ModelSerializer): - class Meta: - model = DVOAFModel - - self.serializer_class = DVOAFSerializer - self.fields_attributes = { - 'positive_integer_field': [ - ('min_value', 0), - ], - 'positive_small_integer_field': [ - ('min_value', 0), - ], - 'email_field': [ - ('max_length', 75), - ], - 'file_field': [ - ('max_length', 100), - ], - 'image_field': [ - ('max_length', 100), - ], - 'slug_field': [ - ('max_length', 50), - ], - 'url_field': [ - ('max_length', 200), - ], - } - - def field_test(self, field): - serializer = self.serializer_class(data={}) - self.assertEqual(serializer.is_valid(), True) - - for attribute in self.fields_attributes[field]: - self.assertEqual( - getattr(serializer.fields[field], attribute[0]), - attribute[1] - ) - - def test_positive_integer_field(self): - self.field_test('positive_integer_field') - - def test_positive_small_integer_field(self): - self.field_test('positive_small_integer_field') - - def test_email_field(self): - self.field_test('email_field') - - def test_file_field(self): - self.field_test('file_field') - - def test_image_field(self): - self.field_test('image_field') - - def test_slug_field(self): - self.field_test('slug_field') - - def test_url_field(self): - self.field_test('url_field') - - -class MetadataSerializer(serializers.Serializer): - field1 = serializers.CharField(3, required=True) - field2 = serializers.CharField(10, required=False) - - -class MetadataSerializerTestCase(TestCase): - def setUp(self): - self.serializer = MetadataSerializer() - - def test_serializer_metadata(self): - metadata = self.serializer.metadata() - expected = { - 'field1': { - 'required': True, - 'max_length': 3, - 'type': 'string', - 'read_only': False - }, - 'field2': { - 'required': False, - 'max_length': 10, - 'type': 'string', - 'read_only': False - } - } - self.assertEqual(expected, metadata) - - -# Regression test for #840 - -class SimpleModel(models.Model): - text = models.CharField(max_length=100) - - -class SimpleModelSerializer(serializers.ModelSerializer): - text = serializers.CharField() - other = serializers.CharField() - - class Meta: - model = SimpleModel - - def validate_other(self, attrs, source): - del attrs['other'] - return attrs - - -class FieldValidationRemovingAttr(TestCase): - def test_removing_non_model_field_in_validation(self): - """ - Removing an attr during field valiation should ensure that it is not - passed through when restoring the object. - - This allows additional non-model fields to be supported. - - Regression test for #840. - """ - serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'}) - self.assertTrue(serializer.is_valid()) - serializer.save() - self.assertEqual(serializer.object.text, 'foo') - - -# Regression test for #878 - -class SimpleTargetModel(models.Model): - text = models.CharField(max_length=100) - - -class SimplePKSourceModelSerializer(serializers.Serializer): - targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True) - text = serializers.CharField() - - -class SimpleSlugSourceModelSerializer(serializers.Serializer): - targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk') - text = serializers.CharField() - - -class SerializerSupportsManyRelationships(TestCase): - def setUp(self): - SimpleTargetModel.objects.create(text='foo') - SimpleTargetModel.objects.create(text='bar') - - def test_serializer_supports_pk_many_relationships(self): - """ - Regression test for #878. - - Note that pk behavior has a different code path to usual cases, - for performance reasons. - """ - serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) - self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) - - def test_serializer_supports_slug_many_relationships(self): - """ - Regression test for #878. - """ - serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) - self.assertTrue(serializer.is_valid()) - self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) - - -class TransformMethodsSerializer(serializers.Serializer): - a = serializers.CharField() - b_renamed = serializers.CharField(source='b') - - def transform_a(self, obj, value): - return value.lower() - - def transform_b_renamed(self, obj, value): - if value is not None: - return 'and ' + value - - -class TestSerializerTransformMethods(TestCase): - def setUp(self): - self.s = TransformMethodsSerializer() - - def test_transform_methods(self): - self.assertEqual( - self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}), - { - 'a': 'green eggs', - 'b_renamed': 'and HAM', - } - ) - - def test_missing_fields(self): - self.assertEqual( - self.s.to_native({'a': 'GREEN EGGS'}), - { - 'a': 'green eggs', - 'b_renamed': None, - } + instance = ExampleObject() + serializer = ExampleSerializer(instance) + with pytest.raises(AttributeError) as exc_info: + serializer.data + msg = str(exc_info.value) + assert msg.startswith( + "Got AttributeError when attempting to get a value for field `incorrect_name` on serializer `ExampleSerializer`.\n" + "The serializer field might be named incorrectly and not match any attribute or key on the `ExampleObject` instance.\n" + "Original exception text was:" ) -class DefaultTrueBooleanModel(models.Model): - cat = models.BooleanField(default=True) - dog = models.BooleanField(default=False) - - -class SerializerDefaultTrueBoolean(TestCase): - - def setUp(self): - super(SerializerDefaultTrueBoolean, self).setUp() - - class DefaultTrueBooleanSerializer(serializers.ModelSerializer): - class Meta: - model = DefaultTrueBooleanModel - fields = ('cat', 'dog') - - self.default_true_boolean_serializer = DefaultTrueBooleanSerializer - - def test_enabled_as_false(self): - serializer = self.default_true_boolean_serializer(data={'cat': False, - 'dog': False}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.data['cat'], False) - self.assertEqual(serializer.data['dog'], False) - - def test_enabled_as_true(self): - serializer = self.default_true_boolean_serializer(data={'cat': True, - 'dog': True}) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.data['cat'], True) - self.assertEqual(serializer.data['dog'], True) - - def test_enabled_partial(self): - serializer = self.default_true_boolean_serializer(data={'cat': False}, - partial=True) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.data['cat'], False) - self.assertEqual(serializer.data['dog'], False) - - -class BoolenFieldTypeTest(TestCase): - ''' - Ensure the various Boolean based model fields are rendered as the proper - field type - - ''' - - def setUp(self): - ''' - Setup an ActionItemSerializer for BooleanTesting - ''' - data = { - 'title': 'b' * 201, - } - self.serializer = ActionItemSerializer(data=data) +class TestUnicodeRepr: + def test_unicode_repr(self): + class ExampleSerializer(serializers.Serializer): + example = serializers.CharField() - def test_booleanfield_type(self): - ''' - Test that BooleanField is infered from models.BooleanField - ''' - bfield = self.serializer.get_fields()['done'] - self.assertEqual(type(bfield), fields.BooleanField) + class ExampleObject: + def __init__(self): + self.example = '한êµ' - def test_nullbooleanfield_type(self): - ''' - Test that BooleanField is infered from models.NullBooleanField + def __repr__(self): + return unicode_repr(self.example) - https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8 - ''' - bfield = self.serializer.get_fields()['started'] - self.assertEqual(type(bfield), fields.BooleanField) + instance = ExampleObject() + serializer = ExampleSerializer(instance) + repr(serializer) # Should not error. diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py index 67a8ed0d..fb881a75 100644 --- a/tests/test_serializer_bulk_update.py +++ b/tests/test_serializer_bulk_update.py @@ -3,6 +3,7 @@ Tests to cover bulk create and update using serializers. """ from __future__ import unicode_literals from django.test import TestCase +from django.utils import six from rest_framework import serializers @@ -42,11 +43,11 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, data) + self.assertEqual(serializer.validated_data, data) def test_bulk_create_errors(self): """ - Correct bulk update serialization should return the input data. + Incorrect bulk create serialization should return errors. """ data = [ @@ -67,7 +68,7 @@ class BulkCreateSerializerTests(TestCase): expected_errors = [ {}, {}, - {'id': ['Enter a whole number.']} + {'id': ['A valid integer is required.']} ] serializer = self.BookSerializer(data=data, many=True) @@ -82,10 +83,12 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) + text_type_string = six.text_type.__name__ + message = 'Invalid data. Expected a dictionary, but got %s.' % text_type_string expected_errors = [ - {'non_field_errors': ['Invalid data']}, - {'non_field_errors': ['Invalid data']}, - {'non_field_errors': ['Invalid data']} + {'non_field_errors': [message]}, + {'non_field_errors': [message]}, + {'non_field_errors': [message]} ] self.assertEqual(serializer.errors, expected_errors) @@ -98,7 +101,7 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items.']} + expected_errors = {'non_field_errors': ['Expected a list of items but got type `int`.']} self.assertEqual(serializer.errors, expected_errors) @@ -115,164 +118,6 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items.']} - - self.assertEqual(serializer.errors, expected_errors) - - -class BulkUpdateSerializerTests(TestCase): - """ - Updating multiple instances using serializers. - """ - - def setUp(self): - class Book(object): - """ - A data type that can be persisted to a mock storage backend - with `.save()` and `.delete()`. - """ - object_map = {} - - def __init__(self, id, title, author): - self.id = id - self.title = title - self.author = author - - def save(self): - Book.object_map[self.id] = self - - def delete(self): - del Book.object_map[self.id] - - class BookSerializer(serializers.Serializer): - id = serializers.IntegerField() - title = serializers.CharField(max_length=100) - author = serializers.CharField(max_length=100) - - def restore_object(self, attrs, instance=None): - if instance: - instance.id = attrs['id'] - instance.title = attrs['title'] - instance.author = attrs['author'] - return instance - return Book(**attrs) - - self.Book = Book - self.BookSerializer = BookSerializer - - data = [ - { - 'id': 0, - 'title': 'The electric kool-aid acid test', - 'author': 'Tom Wolfe' - }, { - 'id': 1, - 'title': 'If this is a man', - 'author': 'Primo Levi' - }, { - 'id': 2, - 'title': 'The wind-up bird chronicle', - 'author': 'Haruki Murakami' - } - ] - - for item in data: - book = Book(item['id'], item['title'], item['author']) - book.save() - - def books(self): - """ - Return all the objects in the mock storage backend. - """ - return self.Book.object_map.values() - - def test_bulk_update_success(self): - """ - Correct bulk update serialization should return the input data. - """ - data = [ - { - 'id': 0, - 'title': 'The electric kool-aid acid test', - 'author': 'Tom Wolfe' - }, { - 'id': 2, - 'title': 'Kafka on the shore', - 'author': 'Haruki Murakami' - } - ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.data, data) - serializer.save() - new_data = self.BookSerializer(self.books(), many=True).data - - self.assertEqual(data, new_data) - - def test_bulk_update_and_create(self): - """ - Bulk update serialization may also include created items. - """ - data = [ - { - 'id': 0, - 'title': 'The electric kool-aid acid test', - 'author': 'Tom Wolfe' - }, { - 'id': 3, - 'title': 'Kafka on the shore', - 'author': 'Haruki Murakami' - } - ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.data, data) - serializer.save() - new_data = self.BookSerializer(self.books(), many=True).data - self.assertEqual(data, new_data) + expected_errors = {'non_field_errors': ['Expected a list of items but got type `dict`.']} - def test_bulk_update_invalid_create(self): - """ - Bulk update serialization without allow_add_remove may not create items. - """ - data = [ - { - 'id': 0, - 'title': 'The electric kool-aid acid test', - 'author': 'Tom Wolfe' - }, { - 'id': 3, - 'title': 'Kafka on the shore', - 'author': 'Haruki Murakami' - } - ] - expected_errors = [ - {}, - {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} - ] - serializer = self.BookSerializer(self.books(), data=data, many=True) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, expected_errors) - - def test_bulk_update_error(self): - """ - Incorrect bulk update serialization should return error data. - """ - data = [ - { - 'id': 0, - 'title': 'The electric kool-aid acid test', - 'author': 'Tom Wolfe' - }, { - 'id': 'foo', - 'title': 'Kafka on the shore', - 'author': 'Haruki Murakami' - } - ] - expected_errors = [ - {}, - {'id': ['Enter a whole number.']} - ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) - self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.errors, expected_errors) diff --git a/tests/test_serializer_empty.py b/tests/test_serializer_empty.py deleted file mode 100644 index 30cff361..00000000 --- a/tests/test_serializer_empty.py +++ /dev/null @@ -1,15 +0,0 @@ -from django.test import TestCase -from rest_framework import serializers - - -class EmptySerializerTestCase(TestCase): - def test_empty_serializer(self): - class FooBarSerializer(serializers.Serializer): - foo = serializers.IntegerField() - bar = serializers.SerializerMethodField('get_bar') - - def get_bar(self, obj): - return 'bar' - - serializer = FooBarSerializer() - self.assertEquals(serializer.data, {'foo': 0}) diff --git a/tests/test_serializer_import.py b/tests/test_serializer_import.py deleted file mode 100644 index 3b8ff4b3..00000000 --- a/tests/test_serializer_import.py +++ /dev/null @@ -1,19 +0,0 @@ -from django.test import TestCase - -from rest_framework import serializers -from tests.accounts.serializers import AccountSerializer - - -class ImportingModelSerializerTests(TestCase): - """ - In some situations like, GH #1225, it is possible, especially in - testing, to import a serializer who's related models have not yet - been resolved by Django. `AccountSerializer` is an example of such - a serializer (imported at the top of this file). - """ - def test_import_model_serializer(self): - """ - The serializer at the top of this file should have been - imported successfully, and we should be able to instantiate it. - """ - self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer) diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py new file mode 100644 index 00000000..35b68ae7 --- /dev/null +++ b/tests/test_serializer_lists.py @@ -0,0 +1,290 @@ +from rest_framework import serializers +from django.utils.datastructures import MultiValueDict + + +class BasicObject: + """ + A mock object for testing serializer save behavior. + """ + def __init__(self, **kwargs): + self._data = kwargs + for key, value in kwargs.items(): + setattr(self, key, value) + + def __eq__(self, other): + if self._data.keys() != other._data.keys(): + return False + for key in self._data.keys(): + if self._data[key] != other._data[key]: + return False + return True + + +class TestListSerializer: + """ + Tests for using a ListSerializer as a top-level serializer. + Note that this is in contrast to using ListSerializer as a field. + """ + + def setup(self): + class IntegerListSerializer(serializers.ListSerializer): + child = serializers.IntegerField() + self.Serializer = IntegerListSerializer + + def test_validate(self): + """ + Validating a list of items should return a list of validated items. + """ + input_data = ["123", "456"] + expected_output = [123, 456] + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_output + + def test_validate_html_input(self): + """ + HTML input should be able to mock list structures using [x] style ids. + """ + input_data = MultiValueDict({"[0]": ["123"], "[1]": ["456"]}) + expected_output = [123, 456] + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_output + + +class TestListSerializerContainingNestedSerializer: + """ + Tests for using a ListSerializer containing another serializer. + """ + + def setup(self): + class TestSerializer(serializers.Serializer): + integer = serializers.IntegerField() + boolean = serializers.BooleanField() + + def create(self, validated_data): + return BasicObject(**validated_data) + + class ObjectListSerializer(serializers.ListSerializer): + child = TestSerializer() + + self.Serializer = ObjectListSerializer + + def test_validate(self): + """ + Validating a list of dictionaries should return a list of + validated dictionaries. + """ + input_data = [ + {"integer": "123", "boolean": "true"}, + {"integer": "456", "boolean": "false"} + ] + expected_output = [ + {"integer": 123, "boolean": True}, + {"integer": 456, "boolean": False} + ] + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_output + + def test_create(self): + """ + Creating from a list of dictionaries should return a list of objects. + """ + input_data = [ + {"integer": "123", "boolean": "true"}, + {"integer": "456", "boolean": "false"} + ] + expected_output = [ + BasicObject(integer=123, boolean=True), + BasicObject(integer=456, boolean=False), + ] + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.save() == expected_output + + def test_serialize(self): + """ + Serialization of a list of objects should return a list of dictionaries. + """ + input_objects = [ + BasicObject(integer=123, boolean=True), + BasicObject(integer=456, boolean=False) + ] + expected_output = [ + {"integer": 123, "boolean": True}, + {"integer": 456, "boolean": False} + ] + serializer = self.Serializer(input_objects) + assert serializer.data == expected_output + + def test_validate_html_input(self): + """ + HTML input should be able to mock list structures using [x] + style prefixes. + """ + input_data = MultiValueDict({ + "[0]integer": ["123"], + "[0]boolean": ["true"], + "[1]integer": ["456"], + "[1]boolean": ["false"] + }) + expected_output = [ + {"integer": 123, "boolean": True}, + {"integer": 456, "boolean": False} + ] + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_output + + +class TestNestedListSerializer: + """ + Tests for using a ListSerializer as a field. + """ + + def setup(self): + class TestSerializer(serializers.Serializer): + integers = serializers.ListSerializer(child=serializers.IntegerField()) + booleans = serializers.ListSerializer(child=serializers.BooleanField()) + + def create(self, validated_data): + return BasicObject(**validated_data) + + self.Serializer = TestSerializer + + def test_validate(self): + """ + Validating a list of items should return a list of validated items. + """ + input_data = { + "integers": ["123", "456"], + "booleans": ["true", "false"] + } + expected_output = { + "integers": [123, 456], + "booleans": [True, False] + } + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_output + + def test_create(self): + """ + Creation with a list of items return an object with an attribute that + is a list of items. + """ + input_data = { + "integers": ["123", "456"], + "booleans": ["true", "false"] + } + expected_output = BasicObject( + integers=[123, 456], + booleans=[True, False] + ) + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.save() == expected_output + + def test_serialize(self): + """ + Serialization of a list of items should return a list of items. + """ + input_object = BasicObject( + integers=[123, 456], + booleans=[True, False] + ) + expected_output = { + "integers": [123, 456], + "booleans": [True, False] + } + serializer = self.Serializer(input_object) + assert serializer.data == expected_output + + def test_validate_html_input(self): + """ + HTML input should be able to mock list structures using [x] + style prefixes. + """ + input_data = MultiValueDict({ + "integers[0]": ["123"], + "integers[1]": ["456"], + "booleans[0]": ["true"], + "booleans[1]": ["false"] + }) + expected_output = { + "integers": [123, 456], + "booleans": [True, False] + } + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_output + + +class TestNestedListOfListsSerializer: + def setup(self): + class TestSerializer(serializers.Serializer): + integers = serializers.ListSerializer( + child=serializers.ListSerializer( + child=serializers.IntegerField() + ) + ) + booleans = serializers.ListSerializer( + child=serializers.ListSerializer( + child=serializers.BooleanField() + ) + ) + + self.Serializer = TestSerializer + + def test_validate(self): + input_data = { + 'integers': [['123', '456'], ['789', '0']], + 'booleans': [['true', 'true'], ['false', 'true']] + } + expected_output = { + "integers": [[123, 456], [789, 0]], + "booleans": [[True, True], [False, True]] + } + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_output + + def test_validate_html_input(self): + """ + HTML input should be able to mock lists of lists using [x][y] + style prefixes. + """ + input_data = MultiValueDict({ + "integers[0][0]": ["123"], + "integers[0][1]": ["456"], + "integers[1][0]": ["789"], + "integers[1][1]": ["000"], + "booleans[0][0]": ["true"], + "booleans[0][1]": ["true"], + "booleans[1][0]": ["false"], + "booleans[1][1]": ["true"] + }) + expected_output = { + "integers": [[123, 456], [789, 0]], + "booleans": [[True, True], [False, True]] + } + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_output + + +class TestListSerializerClass: + """Tests for a custom list_serializer_class.""" + def test_list_serializer_class_validate(self): + class CustomListSerializer(serializers.ListSerializer): + def validate(self, attrs): + raise serializers.ValidationError('Non field error') + + class TestSerializer(serializers.Serializer): + class Meta: + list_serializer_class = CustomListSerializer + + serializer = TestSerializer(data=[], many=True) + assert not serializer.is_valid() + assert serializer.errors == {'non_field_errors': ['Non field error']} diff --git a/tests/test_serializer_nested.py b/tests/test_serializer_nested.py index c09c24db..f5e4b26a 100644 --- a/tests/test_serializer_nested.py +++ b/tests/test_serializer_nested.py @@ -1,349 +1,40 @@ -""" -Tests to cover nested serializers. - -Doesn't cover model serializers. -""" -from __future__ import unicode_literals -from django.test import TestCase from rest_framework import serializers -from . import models - - -class WritableNestedSerializerBasicTests(TestCase): - """ - Tests for deserializing nested entities. - Basic tests that use serializers that simply restore to dicts. - """ - - def setUp(self): - class TrackSerializer(serializers.Serializer): - order = serializers.IntegerField() - title = serializers.CharField(max_length=100) - duration = serializers.IntegerField() - - class AlbumSerializer(serializers.Serializer): - album_name = serializers.CharField(max_length=100) - artist = serializers.CharField(max_length=100) - tracks = TrackSerializer(many=True) - - self.AlbumSerializer = AlbumSerializer - - def test_nested_validation_success(self): - """ - Correct nested serialization should return the input data. - """ - data = { - 'album_name': 'Discovery', - 'artist': 'Daft Punk', - 'tracks': [ - {'order': 1, 'title': 'One More Time', 'duration': 235}, - {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, - {'order': 3, 'title': 'Digital Love', 'duration': 239} - ] - } - - serializer = self.AlbumSerializer(data=data) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, data) - def test_nested_validation_error(self): - """ - Incorrect nested serialization should return appropriate error data. - """ +class TestNestedSerializer: + def setup(self): + class NestedSerializer(serializers.Serializer): + one = serializers.IntegerField(max_value=10) + two = serializers.IntegerField(max_value=10) - data = { - 'album_name': 'Discovery', - 'artist': 'Daft Punk', - 'tracks': [ - {'order': 1, 'title': 'One More Time', 'duration': 235}, - {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, - {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} - ] - } - expected_errors = { - 'tracks': [ - {}, - {}, - {'duration': ['Enter a whole number.']} - ] - } + class TestSerializer(serializers.Serializer): + nested = NestedSerializer() - serializer = self.AlbumSerializer(data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, expected_errors) + self.Serializer = TestSerializer - def test_many_nested_validation_error(self): - """ - Incorrect nested serialization should return appropriate error data - when multiple entities are being deserialized. - """ - - data = [ - { - 'album_name': 'Russian Red', - 'artist': 'I Love Your Glasses', - 'tracks': [ - {'order': 1, 'title': 'Cigarettes', 'duration': 121}, - {'order': 2, 'title': 'No Past Land', 'duration': 198}, - {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} - ] - }, - { - 'album_name': 'Discovery', - 'artist': 'Daft Punk', - 'tracks': [ - {'order': 1, 'title': 'One More Time', 'duration': 235}, - {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, - {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} - ] + def test_nested_validate(self): + input_data = { + 'nested': { + 'one': '1', + 'two': '2', } - ] - expected_errors = [ - {}, - { - 'tracks': [ - {}, - {}, - {'duration': ['Enter a whole number.']} - ] + } + expected_data = { + 'nested': { + 'one': 1, + 'two': 2, } - ] - - serializer = self.AlbumSerializer(data=data, many=True) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, expected_errors) - - -class WritableNestedSerializerObjectTests(TestCase): - """ - Tests for deserializing nested entities. - These tests use serializers that restore to concrete objects. - """ - - def setUp(self): - # Couple of concrete objects that we're going to deserialize into - class Track(object): - def __init__(self, order, title, duration): - self.order, self.title, self.duration = order, title, duration - - def __eq__(self, other): - return ( - self.order == other.order and - self.title == other.title and - self.duration == other.duration - ) - - class Album(object): - def __init__(self, album_name, artist, tracks): - self.album_name, self.artist, self.tracks = album_name, artist, tracks - - def __eq__(self, other): - return ( - self.album_name == other.album_name and - self.artist == other.artist and - self.tracks == other.tracks - ) - - # And their corresponding serializers - class TrackSerializer(serializers.Serializer): - order = serializers.IntegerField() - title = serializers.CharField(max_length=100) - duration = serializers.IntegerField() - - def restore_object(self, attrs, instance=None): - return Track(attrs['order'], attrs['title'], attrs['duration']) - - class AlbumSerializer(serializers.Serializer): - album_name = serializers.CharField(max_length=100) - artist = serializers.CharField(max_length=100) - tracks = TrackSerializer(many=True) - - def restore_object(self, attrs, instance=None): - return Album(attrs['album_name'], attrs['artist'], attrs['tracks']) - - self.Album, self.Track = Album, Track - self.AlbumSerializer = AlbumSerializer - - def test_nested_validation_success(self): - """ - Correct nested serialization should return a restored object - that corresponds to the input data. - """ - - data = { - 'album_name': 'Discovery', - 'artist': 'Daft Punk', - 'tracks': [ - {'order': 1, 'title': 'One More Time', 'duration': 235}, - {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, - {'order': 3, 'title': 'Digital Love', 'duration': 239} - ] } - expected_object = self.Album( - album_name='Discovery', - artist='Daft Punk', - tracks=[ - self.Track(order=1, title='One More Time', duration=235), - self.Track(order=2, title='Aerodynamic', duration=184), - self.Track(order=3, title='Digital Love', duration=239), - ] - ) - - serializer = self.AlbumSerializer(data=data) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, expected_object) - - def test_many_nested_validation_success(self): - """ - Correct nested serialization should return multiple restored objects - that corresponds to the input data when multiple objects are - being deserialized. - """ - - data = [ - { - 'album_name': 'Russian Red', - 'artist': 'I Love Your Glasses', - 'tracks': [ - {'order': 1, 'title': 'Cigarettes', 'duration': 121}, - {'order': 2, 'title': 'No Past Land', 'duration': 198}, - {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} - ] - }, - { - 'album_name': 'Discovery', - 'artist': 'Daft Punk', - 'tracks': [ - {'order': 1, 'title': 'One More Time', 'duration': 235}, - {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, - {'order': 3, 'title': 'Digital Love', 'duration': 239} - ] + serializer = self.Serializer(data=input_data) + assert serializer.is_valid() + assert serializer.validated_data == expected_data + + def test_nested_serialize_empty(self): + expected_data = { + 'nested': { + 'one': None, + 'two': None } - ] - expected_object = [ - self.Album( - album_name='Russian Red', - artist='I Love Your Glasses', - tracks=[ - self.Track(order=1, title='Cigarettes', duration=121), - self.Track(order=2, title='No Past Land', duration=198), - self.Track(order=3, title='They Don\'t Believe', duration=191), - ] - ), - self.Album( - album_name='Discovery', - artist='Daft Punk', - tracks=[ - self.Track(order=1, title='One More Time', duration=235), - self.Track(order=2, title='Aerodynamic', duration=184), - self.Track(order=3, title='Digital Love', duration=239), - ] - ) - ] - - serializer = self.AlbumSerializer(data=data, many=True) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, expected_object) - - -class ForeignKeyNestedSerializerUpdateTests(TestCase): - def setUp(self): - class Artist(object): - def __init__(self, name): - self.name = name - - def __eq__(self, other): - return self.name == other.name - - class Album(object): - def __init__(self, name, artist): - self.name, self.artist = name, artist - - def __eq__(self, other): - return self.name == other.name and self.artist == other.artist - - class ArtistSerializer(serializers.Serializer): - name = serializers.CharField() - - def restore_object(self, attrs, instance=None): - if instance: - instance.name = attrs['name'] - else: - instance = Artist(attrs['name']) - return instance - - class AlbumSerializer(serializers.Serializer): - name = serializers.CharField() - by = ArtistSerializer(source='artist') - - def restore_object(self, attrs, instance=None): - if instance: - instance.name = attrs['name'] - instance.artist = attrs['artist'] - else: - instance = Album(attrs['name'], attrs['artist']) - return instance - - self.Artist = Artist - self.Album = Album - self.AlbumSerializer = AlbumSerializer - - def test_create_via_foreign_key_with_source(self): - """ - Check that we can both *create* and *update* into objects across - ForeignKeys that have a `source` specified. - Regression test for #1170 - """ - data = { - 'name': 'Discovery', - 'by': {'name': 'Daft Punk'}, } - - expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery') - - # create - serializer = self.AlbumSerializer(data=data) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, expected) - - # update - original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters') - serializer = self.AlbumSerializer(instance=original, data=data) - self.assertEqual(serializer.is_valid(), True) - self.assertEqual(serializer.object, expected) - - -class NestedModelSerializerUpdateTests(TestCase): - def test_second_nested_level(self): - john = models.Person.objects.create(name="john") - - post = john.blogpost_set.create(title="Test blog post") - post.blogpostcomment_set.create(text="I hate this blog post") - post.blogpostcomment_set.create(text="I love this blog post") - - class BlogPostCommentSerializer(serializers.ModelSerializer): - class Meta: - model = models.BlogPostComment - - class BlogPostSerializer(serializers.ModelSerializer): - comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set') - - class Meta: - model = models.BlogPost - fields = ('id', 'title', 'comments') - - class PersonSerializer(serializers.ModelSerializer): - posts = BlogPostSerializer(many=True, source='blogpost_set') - - class Meta: - model = models.Person - fields = ('id', 'name', 'age', 'posts') - - serialize = PersonSerializer(instance=john) - deserialize = PersonSerializer(data=serialize.data, instance=john) - self.assertTrue(deserialize.is_valid()) - - result = deserialize.object - result.save() - self.assertEqual(result.id, john.id) + serializer = self.Serializer() + assert serializer.data == expected_data diff --git a/tests/test_serializers.py b/tests/test_serializers.py deleted file mode 100644 index 31c41730..00000000 --- a/tests/test_serializers.py +++ /dev/null @@ -1,31 +0,0 @@ -from django.test import TestCase -from django.utils import six -from rest_framework.serializers import _resolve_model -from tests.models import BasicModel - - -class ResolveModelTests(TestCase): - """ - `_resolve_model` should return a Django model class given the - provided argument is a Django model class itself, or a properly - formatted string representation of one. - """ - def test_resolve_django_model(self): - resolved_model = _resolve_model(BasicModel) - self.assertEqual(resolved_model, BasicModel) - - def test_resolve_string_representation(self): - resolved_model = _resolve_model('tests.BasicModel') - self.assertEqual(resolved_model, BasicModel) - - def test_resolve_unicode_representation(self): - resolved_model = _resolve_model(six.text_type('tests.BasicModel')) - self.assertEqual(resolved_model, BasicModel) - - def test_resolve_non_django_model(self): - with self.assertRaises(ValueError): - _resolve_model(TestCase) - - def test_resolve_improper_string_representation(self): - with self.assertRaises(ValueError): - _resolve_model('BasicModel') diff --git a/tests/test_settings.py b/tests/test_settings.py index e29fc34a..f2ff4ca1 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,22 +1,17 @@ -"""Tests for the settings module""" from __future__ import unicode_literals from django.test import TestCase - -from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS +from rest_framework.settings import APISettings class TestSettings(TestCase): - """Tests relating to the api settings""" - - def test_non_import_errors(self): - """Make sure other errors aren't suppressed.""" - settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) - with self.assertRaises(ValueError): - settings.DEFAULT_MODEL_SERIALIZER_CLASS - def test_import_error_message_maintained(self): - """Make sure real import errors are captured and raised sensibly.""" - settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) - with self.assertRaises(ImportError) as cm: - settings.DEFAULT_MODEL_SERIALIZER_CLASS - self.assertTrue('ImportError' in str(cm.exception)) + """ + Make sure import errors are captured and raised sensibly. + """ + settings = APISettings({ + 'DEFAULT_RENDERER_CLASSES': [ + 'tests.invalid_module.InvalidClassName' + ] + }) + with self.assertRaises(ImportError): + settings.DEFAULT_RENDERER_CLASSES diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py index 75ee0eaa..b04a937e 100644 --- a/tests/test_templatetags.py +++ b/tests/test_templatetags.py @@ -4,6 +4,7 @@ from django.test import TestCase from rest_framework.test import APIRequestFactory from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links + factory = APIRequestFactory() @@ -49,3 +50,37 @@ class Issue1386Tests(TestCase): # example from issue #1386, this shouldn't raise an exception urlize_quoted_links("asdf:[/p]zxcv.com") + + +class URLizerTests(TestCase): + """ + Test if both JSON and YAML URLs are transformed into links well + """ + def _urlize_dict_check(self, data): + """ + For all items in dict test assert that the value is urlized key + """ + for original, urlized in data.items(): + assert urlize_quoted_links(original, nofollow=False) == urlized + + def test_json_with_url(self): + """ + Test if JSON URLs are transformed into links well + """ + data = {} + data['"url": "http://api/users/1/", '] = \ + '"url": "<a href="http://api/users/1/">http://api/users/1/</a>", ' + data['"foo_set": [\n "http://api/foos/1/"\n], '] = \ + '"foo_set": [\n "<a href="http://api/foos/1/">http://api/foos/1/</a>"\n], ' + self._urlize_dict_check(data) + + def test_yaml_with_url(self): + """ + Test if YAML URLs are transformed into links well + """ + data = {} + data['''{users: 'http://api/users/'}'''] = \ + '''{users: '<a href="http://api/users/">http://api/users/</a>'}''' + data['''foo_set: ['http://api/foos/1/']'''] = \ + '''foo_set: ['<a href="http://api/foos/1/">http://api/foos/1/</a>']''' + self._urlize_dict_check(data) diff --git a/tests/test_testing.py b/tests/test_testing.py index 9fd5966e..87d2b61f 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,15 +1,13 @@ -# -- coding: utf-8 -- - +# encoding: utf-8 from __future__ import unicode_literals from django.conf.urls import patterns, url -from io import BytesIO - from django.contrib.auth.models import User from django.shortcuts import redirect from django.test import TestCase from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.test import APIClient, APIRequestFactory, force_authenticate +from io import BytesIO @api_view(['GET', 'POST']) @@ -109,7 +107,7 @@ class TestAPITestClient(TestCase): def test_can_logout(self): """ - `logout()` reset stored credentials + `logout()` resets stored credentials """ self.client.credentials(HTTP_AUTHORIZATION='example') response = self.client.get('/view/') @@ -118,6 +116,18 @@ class TestAPITestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['auth'], b'') + def test_logout_resets_force_authenticate(self): + """ + `logout()` resets any `force_authenticate` + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + self.client.force_authenticate(user) + response = self.client.get('/view/') + self.assertEqual(response.data['user'], 'example') + self.client.logout() + response = self.client.get('/view/') + self.assertEqual(response.data['user'], '') + def test_follow_redirect(self): """ Follow redirect by setting follow argument. diff --git a/tests/test_throttling.py b/tests/test_throttling.py index 7b696f07..cc36a004 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -109,7 +109,7 @@ class ThrottlingTests(TestCase): def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): """ - Ensure the response returns an X-Throttle field with status and next attributes + Ensure the response returns an Retry-After field with status and next attributes set properly. """ request = self.factory.get('/') @@ -117,10 +117,8 @@ class ThrottlingTests(TestCase): self.set_throttle_timer(view, timer) response = view.as_view()(request) if expect is not None: - self.assertEqual(response['X-Throttle-Wait-Seconds'], expect) self.assertEqual(response['Retry-After'], expect) else: - self.assertFalse('X-Throttle-Wait-Seconds' in response) self.assertFalse('Retry-After' in response) def test_seconds_fields(self): @@ -173,13 +171,11 @@ class ThrottlingTests(TestCase): self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) response = MockView_NonTimeThrottling.as_view()(request) - self.assertFalse('X-Throttle-Wait-Seconds' in response) self.assertFalse('Retry-After' in response) self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) response = MockView_NonTimeThrottling.as_view()(request) - self.assertFalse('X-Throttle-Wait-Seconds' in response) self.assertFalse('Retry-After' in response) diff --git a/tests/test_urlizer.py b/tests/test_urlizer.py deleted file mode 100644 index a77aa22a..00000000 --- a/tests/test_urlizer.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework.templatetags.rest_framework import urlize_quoted_links - - -class URLizerTests(TestCase): - """ - Test if both JSON and YAML URLs are transformed into links well - """ - def _urlize_dict_check(self, data): - """ - For all items in dict test assert that the value is urlized key - """ - for original, urlized in data.items(): - assert urlize_quoted_links(original, nofollow=False) == urlized - - def test_json_with_url(self): - """ - Test if JSON URLs are transformed into links well - """ - data = {} - data['"url": "http://api/users/1/", '] = \ - '"url": "<a href="http://api/users/1/">http://api/users/1/</a>", ' - data['"foo_set": [\n "http://api/foos/1/"\n], '] = \ - '"foo_set": [\n "<a href="http://api/foos/1/">http://api/foos/1/</a>"\n], ' - self._urlize_dict_check(data) - - def test_yaml_with_url(self): - """ - Test if YAML URLs are transformed into links well - """ - data = {} - data['''{users: 'http://api/users/'}'''] = \ - '''{users: '<a href="http://api/users/">http://api/users/</a>'}''' - data['''foo_set: ['http://api/foos/1/']'''] = \ - '''foo_set: ['<a href="http://api/foos/1/">http://api/foos/1/</a>']''' - self._urlize_dict_check(data) diff --git a/tests/test_breadcrumbs.py b/tests/test_utils.py index 780fd5c4..8c286ea4 100644 --- a/tests/test_breadcrumbs.py +++ b/tests/test_utils.py @@ -1,8 +1,14 @@ from __future__ import unicode_literals +from django.core.exceptions import ImproperlyConfigured from django.conf.urls import patterns, url from django.test import TestCase +from django.utils import six +from rest_framework.utils.model_meta import _resolve_model from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework.views import APIView +from tests.models import BasicModel + +import rest_framework.utils.model_meta class Root(APIView): @@ -24,6 +30,7 @@ class NestedResourceRoot(APIView): class NestedResourceInstance(APIView): pass + urlpatterns = patterns( '', url(r'^$', Root.as_view()), @@ -35,9 +42,10 @@ urlpatterns = patterns( class BreadcrumbTests(TestCase): - """Tests the breadcrumb functionality used by the HTML renderer.""" - - urls = 'tests.test_breadcrumbs' + """ + Tests the breadcrumb functionality used by the HTML renderer. + """ + urls = 'tests.test_utils' def test_root_breadcrumbs(self): url = '/' @@ -98,3 +106,61 @@ class BreadcrumbTests(TestCase): get_breadcrumbs(url), [('Root', '/')] ) + + +class ResolveModelTests(TestCase): + """ + `_resolve_model` should return a Django model class given the + provided argument is a Django model class itself, or a properly + formatted string representation of one. + """ + def test_resolve_django_model(self): + resolved_model = _resolve_model(BasicModel) + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_string_representation(self): + resolved_model = _resolve_model('tests.BasicModel') + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_unicode_representation(self): + resolved_model = _resolve_model(six.text_type('tests.BasicModel')) + self.assertEqual(resolved_model, BasicModel) + + def test_resolve_non_django_model(self): + with self.assertRaises(ValueError): + _resolve_model(TestCase) + + def test_resolve_improper_string_representation(self): + with self.assertRaises(ValueError): + _resolve_model('BasicModel') + + +class ResolveModelWithPatchedDjangoTests(TestCase): + """ + Test coverage for when Django's `get_model` returns `None`. + + Under certain circumstances Django may return `None` with `get_model`: + http://git.io/get-model-source + + It usually happens with circular imports so it is important that DRF + excepts early, otherwise fault happens downstream and is much more + difficult to debug. + + """ + + def setUp(self): + """Monkeypatch get_model.""" + self.get_model = rest_framework.utils.model_meta.models.get_model + + def get_model(app_label, model_name): + return None + + rest_framework.utils.model_meta.models.get_model = get_model + + def tearDown(self): + """Revert monkeypatching.""" + rest_framework.utils.model_meta.models.get_model = self.get_model + + def test_blows_up_if_model_does_not_resolve(self): + with self.assertRaises(ImproperlyConfigured): + _resolve_model('tests.BasicModel') diff --git a/tests/test_validation.py b/tests/test_validation.py index a46e38ac..4234efd3 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,10 +1,10 @@ from __future__ import unicode_literals -from django.core.validators import MaxValueValidator -from django.core.exceptions import ValidationError +from django.core.validators import RegexValidator, MaxValueValidator from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status from rest_framework.test import APIRequestFactory +import re factory = APIRequestFactory() @@ -23,23 +23,10 @@ class ValidationModelSerializer(serializers.ModelSerializer): class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView): - model = ValidationModel + queryset = ValidationModel.objects.all() serializer_class = ValidationModelSerializer -class TestPreSaveValidationExclusions(TestCase): - def test_pre_save_validation_exclusions(self): - """ - Somewhat weird test case to ensure that we don't perform model - validation on read only fields. - """ - obj = ValidationModel.objects.create(blank_validated_field='') - request = factory.put('/', {}, format='json') - view = UpdateValidationModel().as_view() - response = view(request, pk=obj.pk).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - - # Regression for #653 class ShouldValidateModel(models.Model): @@ -49,11 +36,10 @@ class ShouldValidateModel(models.Model): class ShouldValidateModelSerializer(serializers.ModelSerializer): renamed = serializers.CharField(source='should_validate_field', required=False) - def validate_renamed(self, attrs, source): - value = attrs[source] + def validate_renamed(self, value): if len(value) < 3: raise serializers.ValidationError('Minimum 3 characters.') - return attrs + return value class Meta: model = ShouldValidateModel @@ -102,8 +88,11 @@ class TestAvoidValidation(TestCase): def test_serializer_errors_has_only_invalid_data_error(self): serializer = ValidationSerializer(data='invalid data') self.assertFalse(serializer.is_valid()) - self.assertDictEqual(serializer.errors, - {'non_field_errors': ['Invalid data']}) + self.assertDictEqual(serializer.errors, { + 'non_field_errors': [ + 'Invalid data. Expected a dictionary, but got %s.' % type('').__name__ + ] + }) # regression tests for issue: 1493 @@ -118,7 +107,7 @@ class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): - model = ValidationMaxValueValidatorModel + queryset = ValidationMaxValueValidatorModel.objects.all() serializer_class = ValidationMaxValueValidatorModelSerializer @@ -145,7 +134,7 @@ class TestMaxValueValidatorValidation(TestCase): request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') view = UpdateMaxValueValidationModel().as_view() response = view(request, pk=obj.pk).render() - self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') + self.assertEqual(response.content, b'{"number_value":["Ensure this value is less than or equal to 100."]}') self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -172,17 +161,23 @@ class TestChoiceFieldChoicesValidate(TestCase): f = serializers.ChoiceField(choices=self.CHOICES) value = self.CHOICES[0][0] try: - f.validate(value) - except ValidationError: + f.to_internal_value(value) + except serializers.ValidationError: self.fail("Value %s does not validate" % str(value)) - def test_nested_choices(self): - """ - Make sure a nested value for choices works as expected. - """ - f = serializers.ChoiceField(choices=self.CHOICES_NESTED) - value = self.CHOICES_NESTED[0][1][0][0] - try: - f.validate(value) - except ValidationError: - self.fail("Value %s does not validate" % str(value)) + +class RegexSerializer(serializers.Serializer): + pin = serializers.CharField( + validators=[RegexValidator(regex=re.compile('^[0-9]{4,6}$'), + message='A PIN is 4-6 digits')]) + +expected_repr = """ +RegexSerializer(): + pin = CharField(validators=[<django.core.validators.RegexValidator object>]) +""".strip() + + +class TestRegexSerializer(TestCase): + def test_regex_repr(self): + serializer_repr = repr(RegexSerializer()) + assert serializer_repr == expected_repr diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 00000000..072cec36 --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,293 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers +import datetime + + +def dedent(blocktext): + return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]]) + + +# Tests for `UniqueValidator` +# --------------------------- + +class UniquenessModel(models.Model): + username = models.CharField(unique=True, max_length=100) + + +class UniquenessSerializer(serializers.ModelSerializer): + class Meta: + model = UniquenessModel + + +class AnotherUniquenessModel(models.Model): + code = models.IntegerField(unique=True) + + +class AnotherUniquenessSerializer(serializers.ModelSerializer): + class Meta: + model = AnotherUniquenessModel + + +class TestUniquenessValidation(TestCase): + def setUp(self): + self.instance = UniquenessModel.objects.create(username='existing') + + def test_repr(self): + serializer = UniquenessSerializer() + expected = dedent(""" + UniquenessSerializer(): + id = IntegerField(label='ID', read_only=True) + username = CharField(max_length=100, validators=[<UniqueValidator(queryset=UniquenessModel.objects.all())>]) + """) + assert repr(serializer) == expected + + def test_is_not_unique(self): + data = {'username': 'existing'} + serializer = UniquenessSerializer(data=data) + assert not serializer.is_valid() + assert serializer.errors == {'username': ['This field must be unique.']} + + def test_is_unique(self): + data = {'username': 'other'} + serializer = UniquenessSerializer(data=data) + assert serializer.is_valid() + assert serializer.validated_data == {'username': 'other'} + + def test_updated_instance_excluded(self): + data = {'username': 'existing'} + serializer = UniquenessSerializer(self.instance, data=data) + assert serializer.is_valid() + assert serializer.validated_data == {'username': 'existing'} + + def test_doesnt_pollute_model(self): + instance = AnotherUniquenessModel.objects.create(code='100') + serializer = AnotherUniquenessSerializer(instance) + self.assertEqual( + AnotherUniquenessModel._meta.get_field('code').validators, []) + + # Accessing data shouldn't effect validators on the model + serializer.data + self.assertEqual( + AnotherUniquenessModel._meta.get_field('code').validators, []) + + +# Tests for `UniqueTogetherValidator` +# ----------------------------------- + +class UniquenessTogetherModel(models.Model): + race_name = models.CharField(max_length=100) + position = models.IntegerField() + + class Meta: + unique_together = ('race_name', 'position') + + +class UniquenessTogetherSerializer(serializers.ModelSerializer): + class Meta: + model = UniquenessTogetherModel + + +class TestUniquenessTogetherValidation(TestCase): + def setUp(self): + self.instance = UniquenessTogetherModel.objects.create( + race_name='example', + position=1 + ) + UniquenessTogetherModel.objects.create( + race_name='example', + position=2 + ) + UniquenessTogetherModel.objects.create( + race_name='other', + position=1 + ) + + def test_repr(self): + serializer = UniquenessTogetherSerializer() + expected = dedent(""" + UniquenessTogetherSerializer(): + id = IntegerField(label='ID', read_only=True) + race_name = CharField(max_length=100, required=True) + position = IntegerField(required=True) + class Meta: + validators = [<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>] + """) + assert repr(serializer) == expected + + def test_is_not_unique_together(self): + """ + Failing unique together validation should result in non field errors. + """ + data = {'race_name': 'example', 'position': 2} + serializer = UniquenessTogetherSerializer(data=data) + assert not serializer.is_valid() + assert serializer.errors == { + 'non_field_errors': [ + 'The fields race_name, position must make a unique set.' + ] + } + + def test_is_unique_together(self): + """ + In a unique together validation, one field may be non-unique + so long as the set as a whole is unique. + """ + data = {'race_name': 'other', 'position': 2} + serializer = UniquenessTogetherSerializer(data=data) + assert serializer.is_valid() + assert serializer.validated_data == { + 'race_name': 'other', + 'position': 2 + } + + def test_updated_instance_excluded_from_unique_together(self): + """ + When performing an update, the existing instance does not count + as a match against uniqueness. + """ + data = {'race_name': 'example', 'position': 1} + serializer = UniquenessTogetherSerializer(self.instance, data=data) + assert serializer.is_valid() + assert serializer.validated_data == { + 'race_name': 'example', + 'position': 1 + } + + def test_unique_together_is_required(self): + """ + In a unique together validation, all fields are required. + """ + data = {'position': 2} + serializer = UniquenessTogetherSerializer(data=data, partial=True) + assert not serializer.is_valid() + assert serializer.errors == { + 'race_name': ['This field is required.'] + } + + def test_ignore_excluded_fields(self): + """ + When model fields are not included in a serializer, then uniqueness + validators should not be added for that field. + """ + class ExcludedFieldSerializer(serializers.ModelSerializer): + class Meta: + model = UniquenessTogetherModel + fields = ('id', 'race_name',) + serializer = ExcludedFieldSerializer() + expected = dedent(""" + ExcludedFieldSerializer(): + id = IntegerField(label='ID', read_only=True) + race_name = CharField(max_length=100) + """) + assert repr(serializer) == expected + + +# Tests for `UniqueForDateValidator` +# ---------------------------------- + +class UniqueForDateModel(models.Model): + slug = models.CharField(max_length=100, unique_for_date='published') + published = models.DateField() + + +class UniqueForDateSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueForDateModel + + +class TestUniquenessForDateValidation(TestCase): + def setUp(self): + self.instance = UniqueForDateModel.objects.create( + slug='existing', + published='2000-01-01' + ) + + def test_repr(self): + serializer = UniqueForDateSerializer() + expected = dedent(""" + UniqueForDateSerializer(): + id = IntegerField(label='ID', read_only=True) + slug = CharField(max_length=100) + published = DateField(required=True) + class Meta: + validators = [<UniqueForDateValidator(queryset=UniqueForDateModel.objects.all(), field='slug', date_field='published')>] + """) + assert repr(serializer) == expected + + def test_is_not_unique_for_date(self): + """ + Failing unique for date validation should result in field error. + """ + data = {'slug': 'existing', 'published': '2000-01-01'} + serializer = UniqueForDateSerializer(data=data) + assert not serializer.is_valid() + assert serializer.errors == { + 'slug': ['This field must be unique for the "published" date.'] + } + + def test_is_unique_for_date(self): + """ + Passing unique for date validation. + """ + data = {'slug': 'existing', 'published': '2000-01-02'} + serializer = UniqueForDateSerializer(data=data) + assert serializer.is_valid() + assert serializer.validated_data == { + 'slug': 'existing', + 'published': datetime.date(2000, 1, 2) + } + + def test_updated_instance_excluded_from_unique_for_date(self): + """ + When performing an update, the existing instance does not count + as a match against unique_for_date. + """ + data = {'slug': 'existing', 'published': '2000-01-01'} + serializer = UniqueForDateSerializer(instance=self.instance, data=data) + assert serializer.is_valid() + assert serializer.validated_data == { + 'slug': 'existing', + 'published': datetime.date(2000, 1, 1) + } + + +class HiddenFieldUniqueForDateModel(models.Model): + slug = models.CharField(max_length=100, unique_for_date='published') + published = models.DateTimeField(auto_now_add=True) + + +class TestHiddenFieldUniquenessForDateValidation(TestCase): + def test_repr_date_field_not_included(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = HiddenFieldUniqueForDateModel + fields = ('id', 'slug') + + serializer = TestSerializer() + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + slug = CharField(max_length=100) + published = HiddenField(default=CreateOnlyDefault(<function now>)) + class Meta: + validators = [<UniqueForDateValidator(queryset=HiddenFieldUniqueForDateModel.objects.all(), field='slug', date_field='published')>] + """) + assert repr(serializer) == expected + + def test_repr_date_field_included(self): + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = HiddenFieldUniqueForDateModel + fields = ('id', 'slug', 'published') + + serializer = TestSerializer() + expected = dedent(""" + TestSerializer(): + id = IntegerField(label='ID', read_only=True) + slug = CharField(max_length=100) + published = DateTimeField(default=CreateOnlyDefault(<function now>), read_only=True) + class Meta: + validators = [<UniqueForDateValidator(queryset=HiddenFieldUniqueForDateModel.objects.all(), field='slug', date_field='published')>] + """) + assert repr(serializer) == expected diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py new file mode 100644 index 00000000..4d18a955 --- /dev/null +++ b/tests/test_viewsets.py @@ -0,0 +1,35 @@ +from django.test import TestCase +from rest_framework import status +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.viewsets import GenericViewSet + + +factory = APIRequestFactory() + + +class BasicViewSet(GenericViewSet): + def list(self, request, *args, **kwargs): + return Response({'ACTION': 'LIST'}) + + +class InitializeViewSetsTestCase(TestCase): + def test_initialize_view_set_with_actions(self): + request = factory.get('/', '', content_type='application/json') + my_view = BasicViewSet.as_view(actions={ + 'get': 'list', + }) + + response = my_view(request) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'ACTION': 'LIST'}) + + def test_initialize_view_set_with_empty_actions(self): + try: + BasicViewSet.as_view() + except TypeError as e: + self.assertEqual(str(e), "The `actions` argument must be provided " + "when calling `.as_view()` on a ViewSet. " + "For example `.as_view({'get': 'list'})`") + else: + self.fail("actions must not be empty.") diff --git a/tests/test_write_only_fields.py b/tests/test_write_only_fields.py index aabb18d6..dd3bbd6e 100644 --- a/tests/test_write_only_fields.py +++ b/tests/test_write_only_fields.py @@ -1,42 +1,31 @@ -from django.db import models from django.test import TestCase from rest_framework import serializers -class ExampleModel(models.Model): - email = models.EmailField(max_length=100) - password = models.CharField(max_length=100) - - class WriteOnlyFieldTests(TestCase): - def test_write_only_fields(self): + def setUp(self): class ExampleSerializer(serializers.Serializer): email = serializers.EmailField() password = serializers.CharField(write_only=True) + def create(self, attrs): + return attrs + + self.Serializer = ExampleSerializer + + def write_only_fields_are_present_on_input(self): data = { 'email': 'foo@example.com', 'password': '123' } - serializer = ExampleSerializer(data=data) + serializer = self.Serializer(data=data) self.assertTrue(serializer.is_valid()) - self.assertEquals(serializer.object, data) - self.assertEquals(serializer.data, {'email': 'foo@example.com'}) - - def test_write_only_fields_meta(self): - class ExampleSerializer(serializers.ModelSerializer): - class Meta: - model = ExampleModel - fields = ('email', 'password') - write_only_fields = ('password',) + self.assertEquals(serializer.validated_data, data) - data = { + def write_only_fields_are_not_present_on_output(self): + instance = { 'email': 'foo@example.com', 'password': '123' } - serializer = ExampleSerializer(data=data) - self.assertTrue(serializer.is_valid()) - self.assertTrue(isinstance(serializer.object, ExampleModel)) - self.assertEquals(serializer.object.email, data['email']) - self.assertEquals(serializer.object.password, data['password']) + serializer = self.Serializer(instance) self.assertEquals(serializer.data, {'email': 'foo@example.com'}) diff --git a/tests/users/__init__.py b/tests/users/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/tests/users/__init__.py +++ /dev/null diff --git a/tests/users/models.py b/tests/users/models.py deleted file mode 100644 index 128bac90..00000000 --- a/tests/users/models.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.db import models - - -class User(models.Model): - account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users') - active_record = models.ForeignKey('records.Record', blank=True, null=True) diff --git a/tests/users/serializers.py b/tests/users/serializers.py deleted file mode 100644 index 4893ddb3..00000000 --- a/tests/users/serializers.py +++ /dev/null @@ -1,8 +0,0 @@ -from rest_framework import serializers - -from tests.users.models import User - - -class UserSerializer(serializers.ModelSerializer): - class Meta: - model = User diff --git a/tests/utils.py b/tests/utils.py index 28be81bd..5e902ba9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,6 @@ from contextlib import contextmanager +from django.core.exceptions import ObjectDoesNotExist +from django.core.urlresolvers import NoReverseMatch from django.utils import six from rest_framework.settings import api_settings @@ -23,3 +25,54 @@ def temporary_setting(setting, value, module=None): if module is not None: six.moves.reload_module(module) + + +class MockObject(object): + def __init__(self, **kwargs): + self._kwargs = kwargs + for key, val in kwargs.items(): + setattr(self, key, val) + + def __str__(self): + kwargs_str = ', '.join([ + '%s=%s' % (key, value) + for key, value in sorted(self._kwargs.items()) + ]) + return '<MockObject %s>' % kwargs_str + + +class MockQueryset(object): + def __init__(self, iterable): + self.items = iterable + + def get(self, **lookup): + for item in self.items: + if all([ + getattr(item, key, None) == value + for key, value in lookup.items() + ]): + return item + raise ObjectDoesNotExist() + + +class BadType(object): + """ + When used as a lookup with a `MockQueryset`, these objects + will raise a `TypeError`, as occurs in Django when making + queryset lookups with an incorrect type for the lookup value. + """ + def __eq__(self): + raise TypeError() + + +def mock_reverse(view_name, args=None, kwargs=None, request=None, format=None): + args = args or [] + kwargs = kwargs or {} + value = (args + list(kwargs.values()) + ['-'])[0] + prefix = 'http://example.org' if request else '' + suffix = ('.' + format) if (format is not None) else '' + return '%s/%s/%s%s/' % (prefix, view_name, value, suffix) + + +def fail_reverse(view_name, args=None, kwargs=None, request=None, format=None): + raise NoReverseMatch() diff --git a/tests/views.py b/tests/views.py deleted file mode 100644 index 55935e92..00000000 --- a/tests/views.py +++ /dev/null @@ -1,8 +0,0 @@ -from rest_framework import generics -from .models import NullableForeignKeySource -from .serializers import NullableFKSourceSerializer - - -class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): - model = NullableForeignKeySource - model_serializer_class = NullableFKSourceSerializer |
