aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
authorTom Christie2013-06-05 13:33:19 +0100
committerTom Christie2013-06-05 13:33:19 +0100
commitde00ec95c3007dd90b5b01f7486b430699ea63c1 (patch)
treed2ce8037d446fd9133b3d6a77ebcc49350d7ebc3 /rest_framework/tests
parent9428d6ddb5ebc2d5d9c8557a52be09f0def69cca (diff)
parent2ca243a1144bb2a5461767a21ed14dec1d2b8dc2 (diff)
downloaddjango-rest-framework-de00ec95c3007dd90b5b01f7486b430699ea63c1.tar.bz2
Merge master
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/models.py17
-rw-r--r--rest_framework/tests/relations.py47
-rw-r--r--rest_framework/tests/test_authentication.py (renamed from rest_framework/tests/authentication.py)53
-rw-r--r--rest_framework/tests/test_breadcrumbs.py (renamed from rest_framework/tests/breadcrumbs.py)2
-rw-r--r--rest_framework/tests/test_decorators.py (renamed from rest_framework/tests/decorators.py)0
-rw-r--r--rest_framework/tests/test_description.py (renamed from rest_framework/tests/description.py)0
-rw-r--r--rest_framework/tests/test_fields.py (renamed from rest_framework/tests/fields.py)242
-rw-r--r--rest_framework/tests/test_files.py (renamed from rest_framework/tests/files.py)0
-rw-r--r--rest_framework/tests/test_filters.py (renamed from rest_framework/tests/filterset.py)230
-rw-r--r--rest_framework/tests/test_genericrelations.py (renamed from rest_framework/tests/genericrelations.py)0
-rw-r--r--rest_framework/tests/test_generics.py (renamed from rest_framework/tests/generics.py)115
-rw-r--r--rest_framework/tests/test_htmlrenderer.py (renamed from rest_framework/tests/htmlrenderer.py)14
-rw-r--r--rest_framework/tests/test_hyperlinkedserializers.py (renamed from rest_framework/tests/hyperlinkedserializers.py)50
-rw-r--r--rest_framework/tests/test_multitable_inheritance.py (renamed from rest_framework/tests/multitable_inheritance.py)0
-rw-r--r--rest_framework/tests/test_negotiation.py (renamed from rest_framework/tests/negotiation.py)9
-rw-r--r--rest_framework/tests/test_pagination.py (renamed from rest_framework/tests/pagination.py)14
-rw-r--r--rest_framework/tests/test_parsers.py (renamed from rest_framework/tests/parsers.py)0
-rw-r--r--rest_framework/tests/test_permissions.py (renamed from rest_framework/tests/permissions.py)42
-rw-r--r--rest_framework/tests/test_relations.py100
-rw-r--r--rest_framework/tests/test_relations_hyperlink.py (renamed from rest_framework/tests/relations_hyperlink.py)79
-rw-r--r--rest_framework/tests/test_relations_nested.py (renamed from rest_framework/tests/relations_nested.py)0
-rw-r--r--rest_framework/tests/test_relations_pk.py (renamed from rest_framework/tests/relations_pk.py)121
-rw-r--r--rest_framework/tests/test_relations_slug.py (renamed from rest_framework/tests/relations_slug.py)0
-rw-r--r--rest_framework/tests/test_renderers.py (renamed from rest_framework/tests/renderers.py)66
-rw-r--r--rest_framework/tests/test_request.py (renamed from rest_framework/tests/request.py)2
-rw-r--r--rest_framework/tests/test_response.py (renamed from rest_framework/tests/response.py)131
-rw-r--r--rest_framework/tests/test_reverse.py (renamed from rest_framework/tests/reverse.py)2
-rw-r--r--rest_framework/tests/test_routers.py150
-rw-r--r--rest_framework/tests/test_serializer.py (renamed from rest_framework/tests/serializer.py)560
-rw-r--r--rest_framework/tests/test_serializer_bulk_update.py (renamed from rest_framework/tests/serializer_bulk_update.py)0
-rw-r--r--rest_framework/tests/test_serializer_nested.py (renamed from rest_framework/tests/serializer_nested.py)0
-rw-r--r--rest_framework/tests/test_settings.py (renamed from rest_framework/tests/settings.py)0
-rw-r--r--rest_framework/tests/test_throttling.py (renamed from rest_framework/tests/throttling.py)2
-rw-r--r--rest_framework/tests/test_urlpatterns.py (renamed from rest_framework/tests/urlpatterns.py)0
-rw-r--r--rest_framework/tests/test_validation.py (renamed from rest_framework/tests/validation.py)22
-rw-r--r--rest_framework/tests/test_views.py (renamed from rest_framework/tests/views.py)5
-rw-r--r--rest_framework/tests/testcases.py66
-rw-r--r--rest_framework/tests/tests.py6
38 files changed, 1889 insertions, 258 deletions
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index f2117538..e2d4eacd 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -1,5 +1,7 @@
from __future__ import unicode_literals
from django.db import models
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import serializers
def foobar():
@@ -32,7 +34,7 @@ class Anchor(RESTFrameworkModel):
class BasicModel(RESTFrameworkModel):
- text = models.CharField(max_length=100)
+ text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description."))
class SlugBasedModel(RESTFrameworkModel):
@@ -58,13 +60,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
-# Model to test filtering.
-class FilterableItem(RESTFrameworkModel):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
# Model for regression test for #285
class Comment(RESTFrameworkModel):
@@ -166,3 +161,9 @@ class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='nullable_source')
+
+
+# Serializer used to test BasicModel
+class BasicModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
deleted file mode 100644
index cbf93c65..00000000
--- a/rest_framework/tests/relations.py
+++ /dev/null
@@ -1,47 +0,0 @@
-"""
-General tests for relational fields.
-"""
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class NullModel(models.Model):
- pass
-
-
-class FieldTests(TestCase):
- def test_pk_related_field_with_empty_string(self):
- """
- Regression test for #446
-
- https://github.com/tomchristie/django-rest-framework/issues/446
- """
- field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
- def test_hyperlinked_related_field_with_empty_string(self):
- field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
- def test_slug_related_field_with_empty_string(self):
- field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
-
-class TestManyRelateMixin(TestCase):
- def test_missing_many_to_many_related_field(self):
- '''
- Regression test for #632
-
- https://github.com/tomchristie/django-rest-framework/pull/632
- '''
- field = serializers.RelatedField(many=True, read_only=False)
-
- into = {}
- field.field_from_native({}, None, 'field_name', into)
- self.assertEqual(into['field_name'], [])
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/test_authentication.py
index 8e6d3e51..d46ac079 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -6,6 +6,8 @@ from django.utils import unittest
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework import permissions
+from rest_framework import renderers
+from rest_framework.response import Response
from rest_framework import status
from rest_framework.authentication import (
BaseAuthentication,
@@ -48,7 +50,7 @@ urlpatterns = patterns('',
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
(r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
- (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
permission_classes=[permissions.TokenHasReadWriteScope]))
)
@@ -56,14 +58,14 @@ if oauth2_provider is not None:
urlpatterns += patterns('',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
- url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope])),
)
class BasicAuthTests(TestCase):
"""Basic authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
@@ -102,7 +104,7 @@ class BasicAuthTests(TestCase):
class SessionAuthTests(TestCase):
"""User session authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
@@ -149,7 +151,7 @@ class SessionAuthTests(TestCase):
class TokenAuthTests(TestCase):
"""Token authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
@@ -243,7 +245,7 @@ class IncorrectCredentialsTests(TestCase):
class OAuthTests(TestCase):
"""OAuth 1.0a authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
# these imports are here because oauth is optional and hiding them in try..except block or compat
@@ -429,7 +431,7 @@ class OAuthTests(TestCase):
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
- urls = 'rest_framework.tests.authentication'
+ urls = 'rest_framework.tests.test_authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
@@ -553,3 +555,40 @@ class OAuth2Tests(TestCase):
auth = self._create_authorization_header(token=read_write_access_token.token)
response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
+
+
+class FailingAuthAccessedInRenderer(TestCase):
+ def setUp(self):
+ class AuthAccessingRenderer(renderers.BaseRenderer):
+ media_type = 'text/plain'
+ format = 'txt'
+
+ def render(self, data, media_type=None, renderer_context=None):
+ request = renderer_context['request']
+ if request.user.is_authenticated():
+ return b'authenticated'
+ return b'not authenticated'
+
+ class FailingAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('authentication failed')
+
+ class ExampleView(APIView):
+ authentication_classes = (FailingAuth,)
+ renderer_classes = (AuthAccessingRenderer,)
+
+ def get(self, request):
+ return Response({'foo': 'bar'})
+
+ self.view = ExampleView.as_view()
+
+ def test_failing_auth_accessed_in_renderer(self):
+ """
+ When authentication fails the renderer should still be able to access
+ `request.user` without raising an exception. Particularly relevant
+ to HTML responses that might reasonably access `request.user`.
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ content = response.render().content
+ self.assertEqual(content, b'not authenticated')
diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py
index d9ed647e..41ddf2ce 100644
--- a/rest_framework/tests/breadcrumbs.py
+++ b/rest_framework/tests/test_breadcrumbs.py
@@ -36,7 +36,7 @@ urlpatterns = patterns('',
class BreadcrumbTests(TestCase):
"""Tests the breadcrumb functionality used by the HTML renderer."""
- urls = 'rest_framework.tests.breadcrumbs'
+ urls = 'rest_framework.tests.test_breadcrumbs'
def test_root_breadcrumbs(self):
url = '/'
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/test_decorators.py
index 1016fed3..1016fed3 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/test_decorators.py
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/test_description.py
index 52c1a34c..52c1a34c 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/test_description.py
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/test_fields.py
index 3cdfa0f6..69a0468e 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -2,15 +2,16 @@
General serializer field tests.
"""
from __future__ import unicode_literals
+
import datetime
from decimal import Decimal
-
+from uuid import uuid4
+from django.core import validators
from django.db import models
from django.test import TestCase
-from django.core import validators
-
+from django.utils.datastructures import SortedDict
from rest_framework import serializers
-from rest_framework.serializers import Serializer
+from rest_framework.tests.models import RESTFrameworkModel
class TimestampedModel(models.Model):
@@ -63,6 +64,20 @@ class BasicFieldTests(TestCase):
serializer = CharPrimaryKeyModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, False)
+ def test_dict_field_ordering(self):
+ """
+ Field should preserve dictionary ordering, if it exists.
+ See: https://github.com/tomchristie/django-rest-framework/issues/832
+ """
+ ret = SortedDict()
+ ret['c'] = 1
+ ret['b'] = 1
+ ret['a'] = 1
+ ret['z'] = 1
+ field = serializers.Field()
+ keys = list(field.to_native(ret).keys())
+ self.assertEqual(keys, ['c', 'b', 'a', 'z'])
+
class DateFieldTest(TestCase):
"""
@@ -573,7 +588,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure the serializer works correctly
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_value=9010,
min_value=9000,
max_digits=6,
@@ -591,7 +606,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_value violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_value=100)
s = DecimalSerializer(data={'decimal_field': '123'})
@@ -603,7 +618,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure min_value violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(min_value=100)
s = DecimalSerializer(data={'decimal_field': '99'})
@@ -615,7 +630,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_digits violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_digits=5)
s = DecimalSerializer(data={'decimal_field': '123.456'})
@@ -627,7 +642,7 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_decimal_places violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(decimal_places=3)
s = DecimalSerializer(data={'decimal_field': '123.4567'})
@@ -639,10 +654,215 @@ class DecimalFieldTest(TestCase):
"""
Make sure max_whole_digits violations raises ValidationError
"""
- class DecimalSerializer(Serializer):
+ class DecimalSerializer(serializers.Serializer):
decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
s = DecimalSerializer(data={'decimal_field': '12345.6'})
self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
+
+
+class ChoiceFieldTests(TestCase):
+ """
+ Tests for the ChoiceField options generator
+ """
+
+ SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+ ]
+
+ def test_choices_required(self):
+ """
+ Make sure proper choices are rendered if field is required
+ """
+ f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES)
+ self.assertEqual(f.choices, self.SAMPLE_CHOICES)
+
+ def test_choices_not_required(self):
+ """
+ Make sure proper choices (plus blank) are rendered if the field isn't required
+ """
+ f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+
+
+class EmailFieldTests(TestCase):
+ """
+ Tests for EmailField attribute values
+ """
+
+ class EmailFieldModel(RESTFrameworkModel):
+ email_field = models.EmailField(blank=True)
+
+ class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ email_field = models.EmailField(max_length=150, blank=True)
+
+ def test_default_model_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.EmailFieldModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75)
+
+ def test_given_model_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.EmailFieldWithGivenMaxLengthModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150)
+
+ def test_given_serializer_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ email_field = serializers.EmailField(source='email_field', max_length=20, required=False)
+
+ class Meta:
+ model = self.EmailFieldModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20)
+
+
+class SlugFieldTests(TestCase):
+ """
+ Tests for SlugField attribute values
+ """
+
+ class SlugFieldModel(RESTFrameworkModel):
+ slug_field = models.SlugField(blank=True)
+
+ class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ slug_field = models.SlugField(max_length=84, blank=True)
+
+ def test_default_model_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.SlugFieldModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50)
+
+ def test_given_model_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.SlugFieldWithGivenMaxLengthModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84)
+
+ def test_given_serializer_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ slug_field = serializers.SlugField(source='slug_field',
+ max_length=20, required=False)
+
+ class Meta:
+ model = self.SlugFieldModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'],
+ 'max_length'), 20)
+
+ def test_invalid_slug(self):
+ """
+ Make sure an invalid slug raises ValidationError
+ """
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True)
+
+ class Meta:
+ model = self.SlugFieldModel
+
+ s = SlugFieldSerializer(data={'slug_field': 'a b'})
+
+ self.assertEqual(s.is_valid(), False)
+ self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
+
+
+class URLFieldTests(TestCase):
+ """
+ Tests for URLField attribute values
+ """
+
+ class URLFieldModel(RESTFrameworkModel):
+ url_field = models.URLField(blank=True)
+
+ class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ url_field = models.URLField(max_length=128, blank=True)
+
+ def test_default_model_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.URLFieldModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 200)
+
+ def test_given_model_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.URLFieldWithGivenMaxLengthModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 128)
+
+ def test_given_serializer_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ url_field = serializers.URLField(source='url_field',
+ max_length=20, required=False)
+
+ class Meta:
+ model = self.URLFieldWithGivenMaxLengthModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 20)
+
+
+class FieldMetadata(TestCase):
+ def setUp(self):
+ self.required_field = serializers.Field()
+ self.required_field.label = uuid4().hex
+ self.required_field.required = True
+
+ self.optional_field = serializers.Field()
+ self.optional_field.label = uuid4().hex
+ self.optional_field.required = False
+
+ def test_required(self):
+ self.assertEqual(self.required_field.metadata()['required'], True)
+
+ def test_optional(self):
+ self.assertEqual(self.optional_field.metadata()['required'], False)
+
+ def test_label(self):
+ for field in (self.required_field, self.optional_field):
+ self.assertEqual(field.metadata()['label'], field.label)
+
+
+class FieldCallableDefault(TestCase):
+ def setUp(self):
+ self.simple_callable = lambda: 'foo bar'
+
+ def test_default_can_be_simple_callable(self):
+ """
+ Ensure that the 'default' argument can also be a simple callable.
+ """
+ field = serializers.WritableField(default=self.simple_callable)
+ into = {}
+ field.field_from_native({}, {}, 'field', into)
+ self.assertEqual(into, {'field': 'foo bar'})
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/test_files.py
index 487046ac..487046ac 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/test_files.py
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/test_filters.py
index 023bd016..aaed6247 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/test_filters.py
@@ -1,23 +1,30 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
+from django.db import models
from django.core.urlresolvers import reverse
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
-from rest_framework.tests.models import FilterableItem, BasicModel
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
if django_filters:
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_fields = ['decimal', 'date']
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
# These class are used to test a filter class.
class SeveralFieldsFilter(django_filters.FilterSet):
@@ -32,7 +39,7 @@ if django_filters:
class FilterClassRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_class = SeveralFieldsFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
# These classes are used to test a misconfigured filter class.
class MisconfiguredFilter(django_filters.FilterSet):
@@ -45,12 +52,12 @@ if django_filters:
class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_class = MisconfiguredFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
class FilterClassDetailView(generics.RetrieveAPIView):
model = FilterableItem
filter_class = SeveralFieldsFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
# Regression test for #814
class FilterableItemSerializer(serializers.ModelSerializer):
@@ -61,11 +68,21 @@ if django_filters:
queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
filter_fields = ['decimal', 'date']
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
urlpatterns = patterns('',
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
)
@@ -141,6 +158,17 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_class_root_view(self):
"""
GET requests to filtered ListCreateAPIView that have a filter_class set
@@ -215,7 +243,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
"""
Integration tests for filtered detail views.
"""
- urls = 'rest_framework.tests.filterset'
+ urls = 'rest_framework.tests.test_filters'
def _get_url(self, item):
return reverse('detail-view', kwargs=dict(pk=item.pk))
@@ -256,3 +284,191 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/test_genericrelations.py
index c38bfb9f..c38bfb9f 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/test_genericrelations.py
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/test_generics.py
index eca50d82..37734195 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/test_generics.py
@@ -2,7 +2,7 @@ from __future__ import unicode_literals
from django.db import models
from django.shortcuts import get_object_or_404
from django.test import TestCase
-from rest_framework import generics, serializers, status
+from rest_framework import generics, renderers, serializers, status
from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.compat import six
@@ -39,6 +39,7 @@ class SlugBasedInstanceView(InstanceView):
"""
model = SlugBasedModel
serializer_class = SlugSerializer
+ lookup_field = 'slug'
class TestRootView(TestCase):
@@ -120,7 +121,25 @@ class TestRootView(TestCase):
'text/html'
],
'name': 'Root',
- 'description': 'Example description for OPTIONS.'
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'POST': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ "label": "Text comes here",
+ "help_text": "Text description."
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
}
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
@@ -223,9 +242,9 @@ class TestInstanceView(TestCase):
"""
OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
"""
- request = factory.options('/')
- with self.assertNumQueries(0):
- response = self.view(request).render()
+ request = factory.options('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
expected = {
'parses': [
'application/json',
@@ -237,11 +256,39 @@ class TestInstanceView(TestCase):
'text/html'
],
'name': 'Instance',
- 'description': 'Example description for OPTIONS.'
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'PUT': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ 'label': 'Text comes here',
+ 'help_text': 'Text description.'
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
}
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
+ def test_get_instance_view_incorrect_arg(self):
+ """
+ GET requests with an incorrect pk type, should raise 404, not 500.
+ Regression test for #890.
+ """
+ request = factory.get('/a')
+ with self.assertNumQueries(0):
+ response = self.view(request, pk='a').render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
def test_put_cannot_set_id(self):
"""
PUT requests to create a new object should not be able to set the id.
@@ -434,22 +481,14 @@ class TestFilterBackendAppliedToViews(TestCase):
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
- self.root_view = RootView.as_view()
- self.instance_view = InstanceView.as_view()
- self.original_root_backend = getattr(RootView, 'filter_backend')
- self.original_instance_backend = getattr(InstanceView, 'filter_backend')
-
- def tearDown(self):
- setattr(RootView, 'filter_backend', self.original_root_backend)
- setattr(InstanceView, 'filter_backend', self.original_instance_backend)
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
GET requests to ListCreateAPIView should return filtered list.
"""
- setattr(RootView, 'filter_backend', InclusiveFilterBackend)
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/')
- response = self.root_view(request).render()
+ response = root_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
@@ -458,9 +497,9 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
- setattr(RootView, 'filter_backend', ExclusiveFilterBackend)
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/')
- response = self.root_view(request).render()
+ response = root_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, [])
@@ -468,9 +507,9 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
- setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend)
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1')
- response = self.instance_view(request, pk=1).render()
+ response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {'detail': 'Not found'})
@@ -478,8 +517,40 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
- setattr(InstanceView, 'filter_backend', InclusiveFilterBackend)
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1')
- response = self.instance_view(request, pk=1).render()
+ response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
+
+
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ model = TwoFieldModel
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ return DynamicSerializer
+ return super(DynamicSerializerView, self).get_serializer_class()
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def test_dynamic_serializer_form_in_browsable_api(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ view = DynamicSerializerView.as_view()
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertContains(response, 'field_b')
+ self.assertNotContains(response, 'field_a')
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py
index 8f2e2b5a..8957a43c 100644
--- a/rest_framework/tests/htmlrenderer.py
+++ b/rest_framework/tests/test_htmlrenderer.py
@@ -42,7 +42,7 @@ urlpatterns = patterns('',
class TemplateHTMLRendererTests(TestCase):
- urls = 'rest_framework.tests.htmlrenderer'
+ urls = 'rest_framework.tests.test_htmlrenderer'
def setUp(self):
"""
@@ -66,23 +66,23 @@ class TemplateHTMLRendererTests(TestCase):
def test_simple_html_view(self):
response = self.client.get('/')
self.assertContains(response, "example: foobar")
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_not_found_html_view(self):
response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404 Not Found"))
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403 Forbidden"))
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
class TemplateHTMLRendererExceptionTests(TestCase):
- urls = 'rest_framework.tests.htmlrenderer'
+ urls = 'rest_framework.tests.test_htmlrenderer'
def setUp(self):
"""
@@ -109,10 +109,10 @@ class TemplateHTMLRendererExceptionTests(TestCase):
response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404: Not found"))
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403: Permission denied"))
- self.assertEqual(response['Content-Type'], 'text/html')
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py
index 9a61f299..1894ddb2 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/test_hyperlinkedserializers.py
@@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer):
return Photo(**attrs)
+class AlbumSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
+
+ class Meta:
+ model = Album
+ fields = ('title', 'url')
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView):
class AlbumDetail(generics.RetrieveAPIView):
model = Album
+ serializer_class = AlbumSerializer
+ lookup_field = 'title'
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
@@ -96,7 +106,7 @@ urlpatterns = patterns('',
class TestBasicHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
@@ -133,7 +143,7 @@ class TestBasicHyperlinkedView(TestCase):
class TestManyToManyHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
@@ -180,8 +190,38 @@ class TestManyToManyHyperlinkedView(TestCase):
self.assertEqual(response.data, self.data[0])
+class TestHyperlinkedIdentityFieldLookup(TestCase):
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 Album instances.
+ """
+ titles = ['foo', 'bar', 'baz']
+ for title in titles:
+ album = Album(title=title)
+ album.save()
+ self.detail_view = AlbumDetail.as_view()
+ self.data = {
+ 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
+ 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
+ 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
+ }
+
+ def test_lookup_field(self):
+ """
+ GET requests to AlbumDetail view should return serialized Albums
+ with a url field keyed by `title`.
+ """
+ for album in Album.objects.all():
+ request = factory.get('/albums/{0}/'.format(album.title))
+ response = self.detail_view(request, title=album.title)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[album.title])
+
+
class TestCreateWithForeignKeys(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
@@ -206,7 +246,7 @@ class TestCreateWithForeignKeys(TestCase):
class TestCreateWithForeignKeysAndCustomSlug(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
@@ -231,7 +271,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
class TestOptionalRelationHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.hyperlinkedserializers'
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
def setUp(self):
"""
diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/test_multitable_inheritance.py
index 00c15327..00c15327 100644
--- a/rest_framework/tests/multitable_inheritance.py
+++ b/rest_framework/tests/test_multitable_inheritance.py
diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/test_negotiation.py
index 43721b84..7f84827f 100644
--- a/rest_framework/tests/negotiation.py
+++ b/rest_framework/tests/test_negotiation.py
@@ -3,19 +3,24 @@ from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework.negotiation import DefaultContentNegotiation
from rest_framework.request import Request
+from rest_framework.renderers import BaseRenderer
factory = RequestFactory()
-class MockJSONRenderer(object):
+class MockJSONRenderer(BaseRenderer):
media_type = 'application/json'
-class MockHTMLRenderer(object):
+class MockHTMLRenderer(BaseRenderer):
media_type = 'text/html'
+class NoCharsetSpecifiedRenderer(BaseRenderer):
+ media_type = 'my/media'
+
+
class TestAcceptedMediaType(TestCase):
def setUp(self):
self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/test_pagination.py
index 6b8ef02f..e538a78e 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -1,18 +1,24 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
-import django
+from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
-from rest_framework.tests.models import BasicModel, FilterableItem
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.
@@ -124,7 +130,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
model = FilterableItem
paginate_by = 10
filter_class = DecimalFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
view = FilterFieldsRootView.as_view()
@@ -171,7 +177,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
class BasicFilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
paginate_by = 10
- filter_backend = DecimalFilterBackend
+ filter_backends = (DecimalFilterBackend,)
view = BasicFilterFieldsRootView.as_view()
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/test_parsers.py
index 7699e10c..7699e10c 100644
--- a/rest_framework/tests/parsers.py
+++ b/rest_framework/tests/test_parsers.py
diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/test_permissions.py
index b3993be5..6caaf65b 100644
--- a/rest_framework/tests/permissions.py
+++ b/rest_framework/tests/test_permissions.py
@@ -108,6 +108,48 @@ class ModelPermissionsIntegrationTests(TestCase):
response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ def test_options_permitted(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['POST'])
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
+ def test_options_disallowed(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ def test_options_updateonly(self):
+ request = factory.options('/', content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options('/1', content_type='application/json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
class OwnerModel(models.Model):
text = models.CharField(max_length=100)
diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py
new file mode 100644
index 00000000..d19219c9
--- /dev/null
+++ b/rest_framework/tests/test_relations.py
@@ -0,0 +1,100 @@
+"""
+General tests for relational fields.
+"""
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import BlogPost
+
+
+class NullModel(models.Model):
+ pass
+
+
+class FieldTests(TestCase):
+ def test_pk_related_field_with_empty_string(self):
+ """
+ Regression test for #446
+
+ https://github.com/tomchristie/django-rest-framework/issues/446
+ """
+ field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_hyperlinked_related_field_with_empty_string(self):
+ field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_slug_related_field_with_empty_string(self):
+ field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+
+class TestManyRelatedMixin(TestCase):
+ def test_missing_many_to_many_related_field(self):
+ '''
+ Regression test for #632
+
+ https://github.com/tomchristie/django-rest-framework/pull/632
+ '''
+ field = serializers.RelatedField(many=True, read_only=False)
+
+ into = {}
+ field.field_from_native({}, None, 'field_name', into)
+ self.assertEqual(into['field_name'], [])
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class RelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py
index b1eed9a7..2ca7f4f2 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/test_relations_hyperlink.py
@@ -4,6 +4,7 @@ from django.test.client import RequestFactory
from rest_framework import serializers
from rest_framework.compat import patterns, url
from rest_framework.tests.models import (
+ BlogPost,
ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
)
@@ -16,6 +17,7 @@ def dummy_view(request, pk):
pass
urlpatterns = patterns('',
+ url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
@@ -69,7 +71,7 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
# TODO: Add test that .data cannot be accessed prior to .is_valid
class HyperlinkedManyToManyTests(TestCase):
- urls = 'rest_framework.tests.relations_hyperlink'
+ urls = 'rest_framework.tests.test_relations_hyperlink'
def setUp(self):
for idx in range(1, 4):
@@ -177,7 +179,7 @@ class HyperlinkedManyToManyTests(TestCase):
class HyperlinkedForeignKeyTests(TestCase):
- urls = 'rest_framework.tests.relations_hyperlink'
+ urls = 'rest_framework.tests.test_relations_hyperlink'
def setUp(self):
target = ForeignKeyTarget(name='target-1')
@@ -305,7 +307,7 @@ class HyperlinkedForeignKeyTests(TestCase):
class HyperlinkedNullableForeignKeyTests(TestCase):
- urls = 'rest_framework.tests.relations_hyperlink'
+ urls = 'rest_framework.tests.test_relations_hyperlink'
def setUp(self):
target = ForeignKeyTarget(name='target-1')
@@ -433,7 +435,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
class HyperlinkedNullableOneToOneTests(TestCase):
- urls = 'rest_framework.tests.relations_hyperlink'
+ urls = 'rest_framework.tests.test_relations_hyperlink'
def setUp(self):
target = OneToOneTarget(name='target-1')
@@ -451,3 +453,72 @@ class HyperlinkedNullableOneToOneTests(TestCase):
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
self.assertEqual(serializer.data, expected)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class HyperlinkedRelatedFieldSourceTests(TestCase):
+ urls = 'rest_framework.tests.test_relations_hyperlink'
+
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_manager',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_queryset',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='a.b.c',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/test_relations_nested.py
index 8325580f..8325580f 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/test_relations_nested.py
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/test_relations_pk.py
index 5ce8b567..e2a1b815 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/test_relations_pk.py
@@ -1,7 +1,11 @@
from __future__ import unicode_literals
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+from rest_framework.tests.models import (
+ BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
+)
from rest_framework.compat import six
@@ -124,6 +128,7 @@ class PKManyToManyTests(TestCase):
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
+ self.assertFalse(serializer.fields['targets'].read_only)
expected = [
{'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': 'source-2', 'targets': [1, 2]},
@@ -135,6 +140,7 @@ class PKManyToManyTests(TestCase):
def test_reverse_many_to_many_create(self):
data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data)
+ self.assertFalse(serializer.fields['sources'].read_only)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
@@ -421,3 +427,116 @@ class PKNullableOneToOneTests(TestCase):
{'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
self.assertEqual(serializer.data, expected)
+
+
+# The below models and tests ensure that serializer fields corresponding
+# to a ManyToManyField field with a user-specified ``through`` model are
+# set to read only
+
+
+class ManyToManyThroughTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManyThrough(models.Model):
+ source = models.ForeignKey('ManyToManyThroughSource')
+ target = models.ForeignKey(ManyToManyThroughTarget)
+
+
+class ManyToManyThroughSource(models.Model):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyThroughTarget,
+ related_name='sources',
+ through='ManyToManyThrough')
+
+
+class ManyToManyThroughTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyThroughTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ManyToManyThroughSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyThroughSource
+ fields = ('id', 'name', 'targets')
+
+
+class PKManyToManyThroughTests(TestCase):
+ def setUp(self):
+ self.source = ManyToManyThroughSource.objects.create(
+ name='through-source-1')
+ self.target = ManyToManyThroughTarget.objects.create(
+ name='through-target-1')
+
+ def test_many_to_many_create(self):
+ data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]}
+ serializer = ManyToManyThroughSourceSerializer(data=data)
+ self.assertTrue(serializer.fields['targets'].read_only)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(obj.name, 'source-2')
+ self.assertEqual(obj.targets.count(), 0)
+
+ def test_many_to_many_reverse_create(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]}
+ serializer = ManyToManyThroughTargetSerializer(data=data)
+ self.assertTrue(serializer.fields['sources'].read_only)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ obj = serializer.save()
+ self.assertEqual(obj.name, 'target-2')
+ self.assertEqual(obj.sources.count(), 0)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+
+class PrimaryKeyRelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/test_relations_slug.py
index 435c821c..435c821c 100644
--- a/rest_framework/tests/relations_slug.py
+++ b/rest_framework/tests/test_relations_slug.py
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/test_renderers.py
index 40bac9cb..95b59741 100644
--- a/rest_framework/tests/renderers.py
+++ b/rest_framework/tests/test_renderers.py
@@ -1,14 +1,18 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
from decimal import Decimal
from django.core.cache import cache
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
+from django.utils.translation import ugettext_lazy as _
from rest_framework import status, permissions
from rest_framework.compat import yaml, etree, patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
- XMLRenderer, JSONPRenderer, BrowsableAPIRenderer
+ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
from rest_framework.compat import StringIO
@@ -26,7 +30,7 @@ RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
expected_results = [
- ((elem for elem in [1, 2, 3]), JSONRenderer, '[1, 2, 3]') # Generator
+ ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator
]
@@ -129,12 +133,12 @@ class RendererEndToEndTests(TestCase):
End-to-end testing of renderers using an RendererMixin on a generic view.
"""
- urls = 'rest_framework.tests.renderers'
+ urls = 'rest_framework.tests.test_renderers'
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -142,13 +146,13 @@ class RendererEndToEndTests(TestCase):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -156,7 +160,7 @@ class RendererEndToEndTests(TestCase):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -164,7 +168,7 @@ class RendererEndToEndTests(TestCase):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -175,7 +179,7 @@ class RendererEndToEndTests(TestCase):
RendererB.media_type
)
resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -192,7 +196,7 @@ class RendererEndToEndTests(TestCase):
RendererB.format
)
resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -200,7 +204,7 @@ class RendererEndToEndTests(TestCase):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -213,7 +217,7 @@ class RendererEndToEndTests(TestCase):
)
resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -235,6 +239,13 @@ class JSONRendererTests(TestCase):
Tests specific to the JSON Renderer
"""
+ def test_render_lazy_strings(self):
+ """
+ JSONRenderer should deal with lazy translated strings.
+ """
+ ret = JSONRenderer().render(_('test'))
+ self.assertEqual(ret, b'"test"')
+
def test_without_content_type_args(self):
"""
Test basic JSON rendering.
@@ -243,7 +254,7 @@ class JSONRendererTests(TestCase):
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
# Fix failing test case which depends on version of JSON library.
- self.assertEqual(content, _flat_repr)
+ self.assertEqual(content.decode('utf-8'), _flat_repr)
def test_with_content_type_args(self):
"""
@@ -252,7 +263,24 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2')
- self.assertEqual(strip_trailing_whitespace(content), _indented_repr)
+ self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
+
+ def test_check_ascii(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8'))
+
+
+class UnicodeJSONRendererTests(TestCase):
+ """
+ Tests specific for the Unicode JSON Renderer
+ """
+ def test_proper_encoding(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = UnicodeJSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8'))
class JSONPRendererTests(TestCase):
@@ -260,7 +288,7 @@ class JSONPRendererTests(TestCase):
Tests specific to the JSONP Renderer
"""
- urls = 'rest_framework.tests.renderers'
+ urls = 'rest_framework.tests.test_renderers'
def test_without_callback_with_json_renderer(self):
"""
@@ -269,7 +297,7 @@ class JSONPRendererTests(TestCase):
resp = self.client.get('/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript')
self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('callback(%s);' % _flat_repr).encode('ascii'))
@@ -280,7 +308,7 @@ class JSONPRendererTests(TestCase):
resp = self.client.get('/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript')
self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('callback(%s);' % _flat_repr).encode('ascii'))
@@ -292,7 +320,7 @@ class JSONPRendererTests(TestCase):
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript')
self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript')
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content,
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
@@ -433,7 +461,7 @@ class CacheRenderTest(TestCase):
Tests specific to caching responses
"""
- urls = 'rest_framework.tests.renderers'
+ urls = 'rest_framework.tests.test_renderers'
cache_key = 'just_a_cache_key'
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/test_request.py
index 97e5af20..a5c5e84c 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/test_request.py
@@ -254,7 +254,7 @@ urlpatterns = patterns('',
class TestContentParsingWithAuthentication(TestCase):
- urls = 'rest_framework.tests.request'
+ urls = 'rest_framework.tests.test_request'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/test_response.py
index aecf83f4..eea3c641 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/test_response.py
@@ -1,14 +1,18 @@
from __future__ import unicode_literals
from django.test import TestCase
+from rest_framework.tests.models import BasicModel, BasicModelSerializer
from rest_framework.compat import patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
+from rest_framework import generics
+from rest_framework import routers
from rest_framework import status
from rest_framework.renderers import (
BaseRenderer,
JSONRenderer,
BrowsableAPIRenderer
)
+from rest_framework import viewsets
from rest_framework.settings import api_settings
from rest_framework.compat import six
@@ -21,6 +25,9 @@ class MockJsonRenderer(BaseRenderer):
media_type = 'application/json'
+class MockTextMediaRenderer(BaseRenderer):
+ media_type = 'text/html'
+
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
@@ -44,13 +51,26 @@ class RendererB(BaseRenderer):
return RENDERER_B_SERIALIZER(data)
+class RendererC(RendererB):
+ media_type = 'mock/rendererc'
+ format = 'formatc'
+ charset = "rendererc"
+
+
class MockView(APIView):
- renderer_classes = (RendererA, RendererB)
+ renderer_classes = (RendererA, RendererB, RendererC)
def get(self, request, **kwargs):
return Response(DUMMYCONTENT, status=DUMMYSTATUS)
+class MockViewSettingContentType(APIView):
+ renderer_classes = (RendererA, RendererB, RendererC)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
+
+
class HTMLView(APIView):
renderer_classes = (BrowsableAPIRenderer, )
@@ -65,11 +85,29 @@ class HTMLView1(APIView):
return Response('text')
+class HTMLNewModelViewSet(viewsets.ModelViewSet):
+ model = BasicModel
+
+
+class HTMLNewModelView(generics.ListCreateAPIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = []
+ serializer_class = BasicModelSerializer
+ model = BasicModel
+
+
+new_model_viewset_router = routers.DefaultRouter()
+new_model_viewset_router.register(r'', HTMLNewModelViewSet)
+
+
urlpatterns = patterns('',
- url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
- url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
+ url(r'^html_new_model$', HTMLNewModelView.as_view()),
+ url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
)
@@ -80,12 +118,12 @@ class RendererIntegrationTests(TestCase):
End-to-end testing of renderers using an ResponseMixin on a generic view.
"""
- urls = 'rest_framework.tests.response'
+ urls = 'rest_framework.tests.test_response'
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -93,13 +131,13 @@ class RendererIntegrationTests(TestCase):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -107,7 +145,7 @@ class RendererIntegrationTests(TestCase):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -115,7 +153,7 @@ class RendererIntegrationTests(TestCase):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -126,7 +164,7 @@ class RendererIntegrationTests(TestCase):
RendererB.media_type
)
resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -134,7 +172,7 @@ class RendererIntegrationTests(TestCase):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -142,7 +180,7 @@ class RendererIntegrationTests(TestCase):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -151,7 +189,7 @@ class RendererIntegrationTests(TestCase):
the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@@ -160,7 +198,7 @@ class Issue122Tests(TestCase):
"""
Tests that covers #122.
"""
- urls = 'rest_framework.tests.response'
+ urls = 'rest_framework.tests.test_response'
def test_only_html_renderer(self):
"""
@@ -173,3 +211,68 @@ class Issue122Tests(TestCase):
Test if no infinite recursion occurs.
"""
self.client.get('/html1')
+
+
+class Issue467Tests(TestCase):
+ """
+ Tests for #467
+ """
+
+ urls = 'rest_framework.tests.test_response'
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
+
+
+class Issue807Tests(TestCase):
+ """
+ Covers #807
+ """
+
+ urls = 'rest_framework.tests.test_response'
+
+ def test_does_not_append_charset_by_default(self):
+ """
+ Renderers don't include a charset unless set explicitly.
+ """
+ headers = {"HTTP_ACCEPT": RendererA.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8')
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
+ """
+ If renderer class has charset attribute declared, it gets appended
+ to Response's Content-Type
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_content_type_set_explictly_on_response(self):
+ """
+ The content type may be set explictly on the response.
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/setbyview', **headers)
+ self.assertEqual('setbyview', resp['Content-Type'])
+
+ def test_viewset_label_help_text(self):
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ 'text/html'
+ )
+ resp = self.client.get('/html_new_model_viewset/' + param)
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/test_reverse.py
index cb8d8132..93ef5637 100644
--- a/rest_framework/tests/reverse.py
+++ b/rest_framework/tests/test_reverse.py
@@ -19,7 +19,7 @@ class ReverseTests(TestCase):
"""
Tests for fully qualified URLs when using `reverse`.
"""
- urls = 'rest_framework.tests.reverse'
+ urls = 'rest_framework.tests.test_reverse'
def test_reversed_urls_are_fully_qualified(self):
request = factory.get('/view')
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
new file mode 100644
index 00000000..a7534f70
--- /dev/null
+++ b/rest_framework/tests/test_routers.py
@@ -0,0 +1,150 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import serializers, viewsets
+from rest_framework.compat import include, patterns, url
+from rest_framework.decorators import link, action
+from rest_framework.response import Response
+from rest_framework.routers import SimpleRouter
+
+factory = RequestFactory()
+
+urlpatterns = patterns('',)
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @action()
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @action(methods=['post', 'delete'])
+ def action3(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @link()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
+ # check method to function mapping
+ if endpoint == 'action3':
+ methods_map = ['post', 'delete']
+ elif endpoint.startswith('action'):
+ methods_map = ['post']
+ else:
+ methods_map = ['get']
+ for method in methods_map:
+ self.assertEqual(route.mapping[method], endpoint)
+
+
+class RouterTestModel(models.Model):
+ uuid = models.CharField(max_length=20)
+ text = models.CharField(max_length=200)
+
+
+class TestCustomLookupFields(TestCase):
+ """
+ Ensure that custom lookup fields are correctly routed.
+ """
+ urls = 'rest_framework.tests.test_routers'
+
+ def setUp(self):
+ class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RouterTestModel
+ lookup_field = 'uuid'
+ fields = ('url', 'uuid', 'text')
+
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ serializer_class = NoteSerializer
+ lookup_field = 'uuid'
+
+ RouterTestModel.objects.create(uuid='123', text='foo bar')
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+
+ from rest_framework.tests import test_routers
+ urls = getattr(test_routers, 'urlpatterns')
+ urls += patterns('',
+ url(r'^', include(self.router.urls)),
+ )
+
+ def test_custom_lookup_field_route(self):
+ detail_route = self.router.urls[-1]
+ detail_url_pattern = detail_route.regex.pattern
+ self.assertIn('<uuid>', detail_url_pattern)
+
+ def test_retrieve_lookup_field_list_view(self):
+ response = self.client.get('/notes/')
+ self.assertEqual(response.data,
+ [{
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }]
+ )
+
+ def test_retrieve_lookup_field_detail_view(self):
+ response = self.client.get('/notes/123/')
+ self.assertEqual(response.data,
+ {
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }
+ )
+
+
+class TestTrailingSlash(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_have_trailing_slash_by_default(self):
+ expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestTrailingSlash(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter(trailing_slash=False)
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_can_have_trailing_slash_removed(self):
+ expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/test_serializer.py
index 84e1ee4e..8b87a084 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -1,10 +1,14 @@
from __future__ import unicode_literals
-from django.utils.datastructures import MultiValueDict
+from django.db import models
+from django.db.models.fields import BLANK_CHOICE_DASH
from django.test import TestCase
-from rest_framework import serializers
+from django.utils.datastructures import MultiValueDict
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import serializers, fields, relations
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
- ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
+ ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel)
+from rest_framework.tests.models import BasicModelSerializer
import datetime
import pickle
@@ -43,6 +47,17 @@ class CommentSerializer(serializers.Serializer):
return instance
+class NamesSerializer(serializers.Serializer):
+ first = serializers.CharField()
+ last = serializers.CharField(required=False, default='')
+ initials = serializers.CharField(required=False, default='')
+
+
+class PersonIdentifierSerializer(serializers.Serializer):
+ ssn = serializers.CharField()
+ names = NamesSerializer(source='names', required=False)
+
+
class BookSerializer(serializers.ModelSerializer):
isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
@@ -78,6 +93,29 @@ class PersonSerializer(serializers.ModelSerializer):
read_only_fields = ('age',)
+class NestedSerializer(serializers.Serializer):
+ info = serializers.Field()
+
+
+class ModelSerializerWithNestedSerializer(serializers.ModelSerializer):
+ nested = NestedSerializer(source='*')
+
+ class Meta:
+ model = Person
+
+
+class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
+ """
+ Testing for #652.
+ """
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age', 'info')
+
+
class AlbumsSerializer(serializers.ModelSerializer):
class Meta:
@@ -91,11 +129,6 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
fields = ['some_integer']
-class BrokenModelSerializer(serializers.ModelSerializer):
- class Meta:
- fields = ['some_field']
-
-
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -141,6 +174,42 @@ class BasicTests(TestCase):
self.assertFalse(serializer.object is expected)
self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
+ def test_create_nested(self):
+ """Test a serializer with nested data."""
+ names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ self.assertEqual(serializer.data['names'], names)
+
+ def test_create_partial_nested(self):
+ """Test a serializer with nested data which has missing fields."""
+ names = {'first': 'John'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ expected_names = {'first': 'John', 'last': '', 'initials': ''}
+ data['names'] = expected_names
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is expected_names)
+ self.assertEqual(serializer.data['names'], expected_names)
+
+ def test_null_nested(self):
+ """Test a serializer with a nonexistent nested field"""
+ data = {'ssn': '1234567890'}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ expected = {'ssn': '1234567890', 'names': None}
+ self.assertEqual(serializer.data, expected)
+
def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
@@ -189,6 +258,12 @@ class BasicTests(TestCase):
# Assert age is unchanged (35)
self.assertEqual(instance.age, self.person_data['age'])
+ def test_invalid_read_only_fields(self):
+ """
+ Regression test for #652.
+ """
+ self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+
class DictStyleSerializer(serializers.Serializer):
"""
@@ -344,19 +419,34 @@ class ValidationTests(TestCase):
Assert that a meaningful exception message is outputted when the model
field is missing (e.g. when mistyping ``model``).
"""
+ class BrokenModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ fields = ['some_field']
+
try:
- serializer = BrokenModelSerializer()
+ BrokenModelSerializer()
except AssertionError as e:
self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option")
except:
self.fail('Wrong exception type thrown.')
+ def test_writable_star_source_on_nested_serializer(self):
+ """
+ Assert that a nested serializer instantiated with source='*' correctly
+ expands the data into the outer serializer.
+ """
+ serializer = ModelSerializerWithNestedSerializer(data={
+ 'name': 'marko',
+ 'nested': {'info': 'hi'}},
+ )
+ self.assertEqual(serializer.is_valid(), True)
+
class CustomValidationTests(TestCase):
class CommentSerializerWithFieldValidator(CommentSerializer):
def validate_email(self, attrs, source):
- value = attrs[source]
+ attrs[source]
return attrs
def validate_content(self, attrs, source):
@@ -853,23 +943,6 @@ class RelatedTraversalTest(TestCase):
self.assertEqual(serializer.data, expected)
- def test_queryset_nested_traversal(self):
- """
- Relational fields should be able to use methods as their source.
- """
- BlogPost.objects.create(title='blah')
-
- class QuerysetMethodSerializer(serializers.Serializer):
- blogposts = serializers.RelatedField(many=True, source='get_all_blogposts')
-
- class ClassWithQuerysetMethod(object):
- def get_all_blogposts(self):
- return BlogPost.objects
-
- obj = ClassWithQuerysetMethod()
- serializer = QuerysetMethodSerializer(obj)
- self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']})
-
class SerializerMethodFieldTests(TestCase):
def setUp(self):
@@ -1000,6 +1073,130 @@ class SerializerPickleTests(TestCase):
repr(pickle.loads(pickle.dumps(data, 0)))
+# test for issue #725
+class SeveralChoicesModel(models.Model):
+ color = models.CharField(
+ max_length=10,
+ choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')],
+ blank=False
+ )
+ drink = models.CharField(
+ max_length=10,
+ choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')],
+ blank=False,
+ default='beer'
+ )
+ os = models.CharField(
+ max_length=10,
+ choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')],
+ blank=True
+ )
+ music_genre = models.CharField(
+ max_length=10,
+ choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')],
+ blank=True,
+ default='metal'
+ )
+
+
+class SerializerChoiceFields(TestCase):
+
+ def setUp(self):
+ super(SerializerChoiceFields, self).setUp()
+
+ class SeveralChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SeveralChoicesModel
+ fields = ('color', 'drink', 'os', 'music_genre')
+
+ self.several_choices_serializer = SeveralChoicesSerializer
+
+ def test_choices_blank_false_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['color'].choices,
+ [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')]
+ )
+
+ def test_choices_blank_false_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['drink'].choices,
+ [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')]
+ )
+
+ def test_choices_blank_true_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['os'].choices,
+ BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')]
+ )
+
+ def test_choices_blank_true_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['music_genre'].choices,
+ BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')]
+ )
+
+
+# Regression tests for #675
+class Ticket(models.Model):
+ assigned = models.ForeignKey(
+ Person, related_name='assigned_tickets')
+ reviewer = models.ForeignKey(
+ Person, blank=True, null=True, related_name='reviewed_tickets')
+
+
+class SerializerRelatedChoicesTest(TestCase):
+
+ def setUp(self):
+ super(SerializerRelatedChoicesTest, self).setUp()
+
+ class RelatedChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Ticket
+ fields = ('assigned', 'reviewer')
+
+ self.related_fields_serializer = RelatedChoicesSerializer
+
+ def test_empty_queryset_required(self):
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['assigned'].queryset.count(), 0)
+ self.assertEqual(
+ [x for x in serializer.fields['assigned'].widget.choices],
+ []
+ )
+
+ def test_empty_queryset_not_required(self):
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0)
+ self.assertEqual(
+ [x for x in serializer.fields['reviewer'].widget.choices],
+ [('', '---------')]
+ )
+
+ def test_with_some_persons_required(self):
+ Person.objects.create(name="Lionel Messi")
+ Person.objects.create(name="Xavi Hernandez")
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['assigned'].queryset.count(), 2)
+ self.assertEqual(
+ [x for x in serializer.fields['assigned'].widget.choices],
+ [(1, 'Person object - 1'), (2, 'Person object - 2')]
+ )
+
+ def test_with_some_persons_not_required(self):
+ Person.objects.create(name="Lionel Messi")
+ Person.objects.create(name="Xavi Hernandez")
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2)
+ self.assertEqual(
+ [x for x in serializer.fields['reviewer'].widget.choices],
+ [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')]
+ )
+
+
class DepthTest(TestCase):
def test_implicit_nesting(self):
@@ -1125,3 +1322,312 @@ class DeserializeListTestCase(TestCase):
self.assertFalse(serializer.is_valid())
expected = [{}, {'email': ['This field is required.']}, {}]
self.assertEqual(serializer.errors, expected)
+
+
+# Test for issue 747
+
+class LazyStringModel(object):
+ def __init__(self, lazystring):
+ self.lazystring = lazystring
+
+
+class LazyStringSerializer(serializers.Serializer):
+ lazystring = serializers.Field()
+
+ def restore_object(self, attrs, instance=None):
+ if instance is not None:
+ instance.lazystring = attrs.get('lazystring', instance.lazystring)
+ return instance
+ return LazyStringModel(**attrs)
+
+
+class LazyStringsTestCase(TestCase):
+ def setUp(self):
+ self.model = LazyStringModel(lazystring=_('lazystring'))
+
+ def test_lazy_strings_are_translated(self):
+ serializer = LazyStringSerializer(self.model)
+ self.assertEqual(type(serializer.data['lazystring']),
+ type('lazystring'))
+
+
+# Test for issue #467
+
+class FieldLabelTest(TestCase):
+ def setUp(self):
+ self.serializer_class = BasicModelSerializer
+
+ def test_label_from_model(self):
+ """
+ Validates that label and help_text are correctly copied from the model class.
+ """
+ serializer = self.serializer_class()
+ text_field = serializer.fields['text']
+
+ self.assertEqual('Text comes here', text_field.label)
+ self.assertEqual('Text description.', text_field.help_text)
+
+ def test_field_ctor(self):
+ """
+ This is check that ctor supports both label and help_text.
+ """
+ self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label)
+ self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text)
+ self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label)
+
+
+class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
+
+ def setUp(self):
+ class AMOAFModel(RESTFrameworkModel):
+ char_field = models.CharField(max_length=1024, blank=True)
+ comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
+ decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
+ email_field = models.EmailField(max_length=1024, blank=True)
+ file_field = models.FileField(max_length=1024, blank=True)
+ image_field = models.ImageField(max_length=1024, blank=True)
+ slug_field = models.SlugField(max_length=1024, blank=True)
+ url_field = models.URLField(max_length=1024, blank=True)
+
+ class AMOAFSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AMOAFModel
+
+ self.serializer_class = AMOAFSerializer
+ self.fields_attributes = {
+ 'char_field': [
+ ('max_length', 1024),
+ ],
+ 'comma_separated_integer_field': [
+ ('max_length', 1024),
+ ],
+ 'decimal_field': [
+ ('max_digits', 64),
+ ('decimal_places', 32),
+ ],
+ 'email_field': [
+ ('max_length', 1024),
+ ],
+ 'file_field': [
+ ('max_length', 1024),
+ ],
+ 'image_field': [
+ ('max_length', 1024),
+ ],
+ 'slug_field': [
+ ('max_length', 1024),
+ ],
+ 'url_field': [
+ ('max_length', 1024),
+ ],
+ }
+
+ def field_test(self, field):
+ serializer = self.serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+ for attribute in self.fields_attributes[field]:
+ self.assertEqual(
+ getattr(serializer.fields[field], attribute[0]),
+ attribute[1]
+ )
+
+ def test_char_field(self):
+ self.field_test('char_field')
+
+ def test_comma_separated_integer_field(self):
+ self.field_test('comma_separated_integer_field')
+
+ def test_decimal_field(self):
+ self.field_test('decimal_field')
+
+ def test_email_field(self):
+ self.field_test('email_field')
+
+ def test_file_field(self):
+ self.field_test('file_field')
+
+ def test_image_field(self):
+ self.field_test('image_field')
+
+ def test_slug_field(self):
+ self.field_test('slug_field')
+
+ def test_url_field(self):
+ self.field_test('url_field')
+
+
+class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
+
+ def setUp(self):
+ class DVOAFModel(RESTFrameworkModel):
+ positive_integer_field = models.PositiveIntegerField(blank=True)
+ positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
+ email_field = models.EmailField(blank=True)
+ file_field = models.FileField(blank=True)
+ image_field = models.ImageField(blank=True)
+ slug_field = models.SlugField(blank=True)
+ url_field = models.URLField(blank=True)
+
+ class DVOAFSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DVOAFModel
+
+ self.serializer_class = DVOAFSerializer
+ self.fields_attributes = {
+ 'positive_integer_field': [
+ ('min_value', 0),
+ ],
+ 'positive_small_integer_field': [
+ ('min_value', 0),
+ ],
+ 'email_field': [
+ ('max_length', 75),
+ ],
+ 'file_field': [
+ ('max_length', 100),
+ ],
+ 'image_field': [
+ ('max_length', 100),
+ ],
+ 'slug_field': [
+ ('max_length', 50),
+ ],
+ 'url_field': [
+ ('max_length', 200),
+ ],
+ }
+
+ def field_test(self, field):
+ serializer = self.serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+ for attribute in self.fields_attributes[field]:
+ self.assertEqual(
+ getattr(serializer.fields[field], attribute[0]),
+ attribute[1]
+ )
+
+ def test_positive_integer_field(self):
+ self.field_test('positive_integer_field')
+
+ def test_positive_small_integer_field(self):
+ self.field_test('positive_small_integer_field')
+
+ def test_email_field(self):
+ self.field_test('email_field')
+
+ def test_file_field(self):
+ self.field_test('file_field')
+
+ def test_image_field(self):
+ self.field_test('image_field')
+
+ def test_slug_field(self):
+ self.field_test('slug_field')
+
+ def test_url_field(self):
+ self.field_test('url_field')
+
+
+class MetadataSerializer(serializers.Serializer):
+ field1 = serializers.CharField(3, required=True)
+ field2 = serializers.CharField(10, required=False)
+
+
+class MetadataSerializerTestCase(TestCase):
+ def setUp(self):
+ self.serializer = MetadataSerializer()
+
+ def test_serializer_metadata(self):
+ metadata = self.serializer.metadata()
+ expected = {
+ 'field1': {
+ 'required': True,
+ 'max_length': 3,
+ 'type': 'string',
+ 'read_only': False
+ },
+ 'field2': {
+ 'required': False,
+ 'max_length': 10,
+ 'type': 'string',
+ 'read_only': False
+ }
+ }
+ self.assertEqual(expected, metadata)
+
+
+### Regression test for #840
+
+class SimpleModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimpleModelSerializer(serializers.ModelSerializer):
+ text = serializers.CharField()
+ other = serializers.CharField()
+
+ class Meta:
+ model = SimpleModel
+
+ def validate_other(self, attrs, source):
+ del attrs['other']
+ return attrs
+
+
+class FieldValidationRemovingAttr(TestCase):
+ def test_removing_non_model_field_in_validation(self):
+ """
+ Removing an attr during field valiation should ensure that it is not
+ passed through when restoring the object.
+
+ This allows additional non-model fields to be supported.
+
+ Regression test for #840.
+ """
+ serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.object.text, 'foo')
+
+
+### Regression test for #878
+
+class SimpleTargetModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimplePKSourceModelSerializer(serializers.Serializer):
+ targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True)
+ text = serializers.CharField()
+
+
+class SimpleSlugSourceModelSerializer(serializers.Serializer):
+ targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk')
+ text = serializers.CharField()
+
+
+class SerializerSupportsManyRelationships(TestCase):
+ def setUp(self):
+ SimpleTargetModel.objects.create(text='foo')
+ SimpleTargetModel.objects.create(text='bar')
+
+ def test_serializer_supports_pk_many_relationships(self):
+ """
+ Regression test for #878.
+
+ Note that pk behavior has a different code path to usual cases,
+ for performance reasons.
+ """
+ serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+ def test_serializer_supports_slug_many_relationships(self):
+ """
+ Regression test for #878.
+ """
+ serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/test_serializer_bulk_update.py
index 8b0ded1a..8b0ded1a 100644
--- a/rest_framework/tests/serializer_bulk_update.py
+++ b/rest_framework/tests/test_serializer_bulk_update.py
diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/test_serializer_nested.py
index 71d0e24b..71d0e24b 100644
--- a/rest_framework/tests/serializer_nested.py
+++ b/rest_framework/tests/test_serializer_nested.py
diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/test_settings.py
index 857375c2..857375c2 100644
--- a/rest_framework/tests/settings.py
+++ b/rest_framework/tests/test_settings.py
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/test_throttling.py
index 11cbd8eb..da400b2f 100644
--- a/rest_framework/tests/throttling.py
+++ b/rest_framework/tests/test_throttling.py
@@ -36,7 +36,7 @@ class MockView_MinuteThrottling(APIView):
class ThrottlingTests(TestCase):
- urls = 'rest_framework.tests.throttling'
+ urls = 'rest_framework.tests.test_throttling'
def setUp(self):
"""
diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/test_urlpatterns.py
index 29ed4a96..29ed4a96 100644
--- a/rest_framework/tests/urlpatterns.py
+++ b/rest_framework/tests/test_urlpatterns.py
diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/test_validation.py
index cbdd6515..a6ec0e99 100644
--- a/rest_framework/tests/validation.py
+++ b/rest_framework/tests/test_validation.py
@@ -63,3 +63,25 @@ class TestPreSaveValidationExclusions(TestCase):
# does not have `blank=True`, so this serializer should not validate.
serializer = ShouldValidateModelSerializer(data={'renamed': ''})
self.assertEqual(serializer.is_valid(), False)
+
+
+class ValidationSerializer(serializers.Serializer):
+ foo = serializers.CharField()
+
+ def validate_foo(self, attrs, source):
+ raise serializers.ValidationError("foo invalid")
+
+ def validate(self, attrs):
+ raise serializers.ValidationError("serializer invalid")
+
+
+class TestAvoidValidation(TestCase):
+ """
+ If serializer was initialized with invalid data (None or non dict-like), it
+ should avoid validation layer (validate_<field> and validate methods)
+ """
+ def test_serializer_errors_has_only_invalid_data_error(self):
+ serializer = ValidationSerializer(data='invalid data')
+ self.assertFalse(serializer.is_valid())
+ self.assertDictEqual(serializer.errors,
+ {'non_field_errors': ['Invalid data']})
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/test_views.py
index 994cf6dc..2767d24c 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/test_views.py
@@ -1,12 +1,15 @@
from __future__ import unicode_literals
+
+import copy
+
from django.test import TestCase
from django.test.client import RequestFactory
+
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.views import APIView
-import copy
factory = RequestFactory()
diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py
deleted file mode 100644
index f8c2579e..00000000
--- a/rest_framework/tests/testcases.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# http://djangosnippets.org/snippets/1011/
-from __future__ import unicode_literals
-from django.conf import settings
-from django.core.management import call_command
-from django.db.models import loading
-from django.test import TestCase
-
-NO_SETTING = ('!', None)
-
-
-class TestSettingsManager(object):
- """
- A class which can modify some Django settings temporarily for a
- test and then revert them to their original values later.
-
- Automatically handles resyncing the DB if INSTALLED_APPS is
- modified.
-
- """
- def __init__(self):
- self._original_settings = {}
-
- def set(self, **kwargs):
- for k, v in kwargs.iteritems():
- self._original_settings.setdefault(k, getattr(settings, k,
- NO_SETTING))
- setattr(settings, k, v)
- if 'INSTALLED_APPS' in kwargs:
- self.syncdb()
-
- def syncdb(self):
- loading.cache.loaded = False
- call_command('syncdb', verbosity=0)
-
- def revert(self):
- for k, v in self._original_settings.iteritems():
- if v == NO_SETTING:
- delattr(settings, k)
- else:
- setattr(settings, k, v)
- if 'INSTALLED_APPS' in self._original_settings:
- self.syncdb()
- self._original_settings = {}
-
-
-class SettingsTestCase(TestCase):
- """
- A subclass of the Django TestCase with a settings_manager
- attribute which is an instance of TestSettingsManager.
-
- Comes with a tearDown() method that calls
- self.settings_manager.revert().
-
- """
- def __init__(self, *args, **kwargs):
- super(SettingsTestCase, self).__init__(*args, **kwargs)
- self.settings_manager = TestSettingsManager()
-
- def tearDown(self):
- self.settings_manager.revert()
-
-
-class TestModelsTestCase(SettingsTestCase):
- def setUp(self, *args, **kwargs):
- installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',)
- self.settings_manager.set(INSTALLED_APPS=installed_apps)
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
index 08f88e11..554ebd1a 100644
--- a/rest_framework/tests/tests.py
+++ b/rest_framework/tests/tests.py
@@ -4,11 +4,13 @@ runner to pick up the tests. Yowzers.
"""
from __future__ import unicode_literals
import os
+import django
modules = [filename.rsplit('.', 1)[0]
for filename in os.listdir(os.path.dirname(__file__))
if filename.endswith('.py') and not filename.startswith('_')]
__test__ = dict()
-for module in modules:
- exec("from rest_framework.tests.%s import *" % module)
+if django.VERSION < (1, 6):
+ for module in modules:
+ exec("from rest_framework.tests.%s import *" % module)