diff options
| author | Tom Christie | 2013-06-05 13:33:19 +0100 |
|---|---|---|
| committer | Tom Christie | 2013-06-05 13:33:19 +0100 |
| commit | de00ec95c3007dd90b5b01f7486b430699ea63c1 (patch) | |
| tree | d2ce8037d446fd9133b3d6a77ebcc49350d7ebc3 /rest_framework/tests | |
| parent | 9428d6ddb5ebc2d5d9c8557a52be09f0def69cca (diff) | |
| parent | 2ca243a1144bb2a5461767a21ed14dec1d2b8dc2 (diff) | |
| download | django-rest-framework-de00ec95c3007dd90b5b01f7486b430699ea63c1.tar.bz2 | |
Merge master
Diffstat (limited to 'rest_framework/tests')
| -rw-r--r-- | rest_framework/tests/models.py | 17 | ||||
| -rw-r--r-- | rest_framework/tests/relations.py | 47 | ||||
| -rw-r--r-- | rest_framework/tests/test_authentication.py (renamed from rest_framework/tests/authentication.py) | 53 | ||||
| -rw-r--r-- | rest_framework/tests/test_breadcrumbs.py (renamed from rest_framework/tests/breadcrumbs.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_decorators.py (renamed from rest_framework/tests/decorators.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_description.py (renamed from rest_framework/tests/description.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_fields.py (renamed from rest_framework/tests/fields.py) | 242 | ||||
| -rw-r--r-- | rest_framework/tests/test_files.py (renamed from rest_framework/tests/files.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_filters.py (renamed from rest_framework/tests/filterset.py) | 230 | ||||
| -rw-r--r-- | rest_framework/tests/test_genericrelations.py (renamed from rest_framework/tests/genericrelations.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_generics.py (renamed from rest_framework/tests/generics.py) | 115 | ||||
| -rw-r--r-- | rest_framework/tests/test_htmlrenderer.py (renamed from rest_framework/tests/htmlrenderer.py) | 14 | ||||
| -rw-r--r-- | rest_framework/tests/test_hyperlinkedserializers.py (renamed from rest_framework/tests/hyperlinkedserializers.py) | 50 | ||||
| -rw-r--r-- | rest_framework/tests/test_multitable_inheritance.py (renamed from rest_framework/tests/multitable_inheritance.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_negotiation.py (renamed from rest_framework/tests/negotiation.py) | 9 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py (renamed from rest_framework/tests/pagination.py) | 14 | ||||
| -rw-r--r-- | rest_framework/tests/test_parsers.py (renamed from rest_framework/tests/parsers.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_permissions.py (renamed from rest_framework/tests/permissions.py) | 42 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations.py | 100 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_hyperlink.py (renamed from rest_framework/tests/relations_hyperlink.py) | 79 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_nested.py (renamed from rest_framework/tests/relations_nested.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_pk.py (renamed from rest_framework/tests/relations_pk.py) | 121 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_slug.py (renamed from rest_framework/tests/relations_slug.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_renderers.py (renamed from rest_framework/tests/renderers.py) | 66 | ||||
| -rw-r--r-- | rest_framework/tests/test_request.py (renamed from rest_framework/tests/request.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_response.py (renamed from rest_framework/tests/response.py) | 131 | ||||
| -rw-r--r-- | rest_framework/tests/test_reverse.py (renamed from rest_framework/tests/reverse.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 150 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer.py (renamed from rest_framework/tests/serializer.py) | 560 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer_bulk_update.py (renamed from rest_framework/tests/serializer_bulk_update.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer_nested.py (renamed from rest_framework/tests/serializer_nested.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_settings.py (renamed from rest_framework/tests/settings.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_throttling.py (renamed from rest_framework/tests/throttling.py) | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_urlpatterns.py (renamed from rest_framework/tests/urlpatterns.py) | 0 | ||||
| -rw-r--r-- | rest_framework/tests/test_validation.py (renamed from rest_framework/tests/validation.py) | 22 | ||||
| -rw-r--r-- | rest_framework/tests/test_views.py (renamed from rest_framework/tests/views.py) | 5 | ||||
| -rw-r--r-- | rest_framework/tests/testcases.py | 66 | ||||
| -rw-r--r-- | rest_framework/tests/tests.py | 6 |
38 files changed, 1889 insertions, 258 deletions
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index f2117538..e2d4eacd 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals from django.db import models +from django.utils.translation import ugettext_lazy as _ +from rest_framework import serializers def foobar(): @@ -32,7 +34,7 @@ class Anchor(RESTFrameworkModel): class BasicModel(RESTFrameworkModel): - text = models.CharField(max_length=100) + text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description.")) class SlugBasedModel(RESTFrameworkModel): @@ -58,13 +60,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) -# Model to test filtering. -class FilterableItem(RESTFrameworkModel): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - # Model for regression test for #285 class Comment(RESTFrameworkModel): @@ -166,3 +161,9 @@ class NullableOneToOneSource(RESTFrameworkModel): name = models.CharField(max_length=100) target = models.OneToOneField(OneToOneTarget, null=True, blank=True, related_name='nullable_source') + + +# Serializer used to test BasicModel +class BasicModelSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py deleted file mode 100644 index cbf93c65..00000000 --- a/rest_framework/tests/relations.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -General tests for relational fields. -""" -from __future__ import unicode_literals -from django.db import models -from django.test import TestCase -from rest_framework import serializers - - -class NullModel(models.Model): - pass - - -class FieldTests(TestCase): - def test_pk_related_field_with_empty_string(self): - """ - Regression test for #446 - - https://github.com/tomchristie/django-rest-framework/issues/446 - """ - field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - def test_hyperlinked_related_field_with_empty_string(self): - field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - def test_slug_related_field_with_empty_string(self): - field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') - self.assertRaises(serializers.ValidationError, field.from_native, '') - self.assertRaises(serializers.ValidationError, field.from_native, []) - - -class TestManyRelateMixin(TestCase): - def test_missing_many_to_many_related_field(self): - ''' - Regression test for #632 - - https://github.com/tomchristie/django-rest-framework/pull/632 - ''' - field = serializers.RelatedField(many=True, read_only=False) - - into = {} - field.field_from_native({}, None, 'field_name', into) - self.assertEqual(into['field_name'], []) diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/test_authentication.py index 8e6d3e51..d46ac079 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -6,6 +6,8 @@ from django.utils import unittest from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions from rest_framework import permissions +from rest_framework import renderers +from rest_framework.response import Response from rest_framework import status from rest_framework.authentication import ( BaseAuthentication, @@ -48,7 +50,7 @@ urlpatterns = patterns('', (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), - (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], + (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], permission_classes=[permissions.TokenHasReadWriteScope])) ) @@ -56,14 +58,14 @@ if oauth2_provider is not None: urlpatterns += patterns('', url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), - url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], + url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], permission_classes=[permissions.TokenHasReadWriteScope])), ) class BasicAuthTests(TestCase): """Basic authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -102,7 +104,7 @@ class BasicAuthTests(TestCase): class SessionAuthTests(TestCase): """User session authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -149,7 +151,7 @@ class SessionAuthTests(TestCase): class TokenAuthTests(TestCase): """Token authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -243,7 +245,7 @@ class IncorrectCredentialsTests(TestCase): class OAuthTests(TestCase): """OAuth 1.0a authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): # these imports are here because oauth is optional and hiding them in try..except block or compat @@ -429,7 +431,7 @@ class OAuthTests(TestCase): class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" - urls = 'rest_framework.tests.authentication' + urls = 'rest_framework.tests.test_authentication' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) @@ -553,3 +555,40 @@ class OAuth2Tests(TestCase): auth = self._create_authorization_header(token=read_write_access_token.token) response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) + + +class FailingAuthAccessedInRenderer(TestCase): + def setUp(self): + class AuthAccessingRenderer(renderers.BaseRenderer): + media_type = 'text/plain' + format = 'txt' + + def render(self, data, media_type=None, renderer_context=None): + request = renderer_context['request'] + if request.user.is_authenticated(): + return b'authenticated' + return b'not authenticated' + + class FailingAuth(BaseAuthentication): + def authenticate(self, request): + raise exceptions.AuthenticationFailed('authentication failed') + + class ExampleView(APIView): + authentication_classes = (FailingAuth,) + renderer_classes = (AuthAccessingRenderer,) + + def get(self, request): + return Response({'foo': 'bar'}) + + self.view = ExampleView.as_view() + + def test_failing_auth_accessed_in_renderer(self): + """ + When authentication fails the renderer should still be able to access + `request.user` without raising an exception. Particularly relevant + to HTML responses that might reasonably access `request.user`. + """ + request = factory.get('/') + response = self.view(request) + content = response.render().content + self.assertEqual(content, b'not authenticated') diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py index d9ed647e..41ddf2ce 100644 --- a/rest_framework/tests/breadcrumbs.py +++ b/rest_framework/tests/test_breadcrumbs.py @@ -36,7 +36,7 @@ urlpatterns = patterns('', class BreadcrumbTests(TestCase): """Tests the breadcrumb functionality used by the HTML renderer.""" - urls = 'rest_framework.tests.breadcrumbs' + urls = 'rest_framework.tests.test_breadcrumbs' def test_root_breadcrumbs(self): url = '/' diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/test_decorators.py index 1016fed3..1016fed3 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/test_decorators.py diff --git a/rest_framework/tests/description.py b/rest_framework/tests/test_description.py index 52c1a34c..52c1a34c 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/test_description.py diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/test_fields.py index 3cdfa0f6..69a0468e 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/test_fields.py @@ -2,15 +2,16 @@ General serializer field tests. """ from __future__ import unicode_literals + import datetime from decimal import Decimal - +from uuid import uuid4 +from django.core import validators from django.db import models from django.test import TestCase -from django.core import validators - +from django.utils.datastructures import SortedDict from rest_framework import serializers -from rest_framework.serializers import Serializer +from rest_framework.tests.models import RESTFrameworkModel class TimestampedModel(models.Model): @@ -63,6 +64,20 @@ class BasicFieldTests(TestCase): serializer = CharPrimaryKeyModelSerializer() self.assertEqual(serializer.fields['id'].read_only, False) + def test_dict_field_ordering(self): + """ + Field should preserve dictionary ordering, if it exists. + See: https://github.com/tomchristie/django-rest-framework/issues/832 + """ + ret = SortedDict() + ret['c'] = 1 + ret['b'] = 1 + ret['a'] = 1 + ret['z'] = 1 + field = serializers.Field() + keys = list(field.to_native(ret).keys()) + self.assertEqual(keys, ['c', 'b', 'a', 'z']) + class DateFieldTest(TestCase): """ @@ -573,7 +588,7 @@ class DecimalFieldTest(TestCase): """ Make sure the serializer works correctly """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_value=9010, min_value=9000, max_digits=6, @@ -591,7 +606,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_value violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_value=100) s = DecimalSerializer(data={'decimal_field': '123'}) @@ -603,7 +618,7 @@ class DecimalFieldTest(TestCase): """ Make sure min_value violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(min_value=100) s = DecimalSerializer(data={'decimal_field': '99'}) @@ -615,7 +630,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_digits violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_digits=5) s = DecimalSerializer(data={'decimal_field': '123.456'}) @@ -627,7 +642,7 @@ class DecimalFieldTest(TestCase): """ Make sure max_decimal_places violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(decimal_places=3) s = DecimalSerializer(data={'decimal_field': '123.4567'}) @@ -639,10 +654,215 @@ class DecimalFieldTest(TestCase): """ Make sure max_whole_digits violations raises ValidationError """ - class DecimalSerializer(Serializer): + class DecimalSerializer(serializers.Serializer): decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) s = DecimalSerializer(data={'decimal_field': '12345.6'}) self.assertFalse(s.is_valid()) - self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
\ No newline at end of file + self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) + + +class ChoiceFieldTests(TestCase): + """ + Tests for the ChoiceField options generator + """ + + SAMPLE_CHOICES = [ + ('red', 'Red'), + ('green', 'Green'), + ('blue', 'Blue'), + ] + + def test_choices_required(self): + """ + Make sure proper choices are rendered if field is required + """ + f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES) + self.assertEqual(f.choices, self.SAMPLE_CHOICES) + + def test_choices_not_required(self): + """ + Make sure proper choices (plus blank) are rendered if the field isn't required + """ + f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) + self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) + + +class EmailFieldTests(TestCase): + """ + Tests for EmailField attribute values + """ + + class EmailFieldModel(RESTFrameworkModel): + email_field = models.EmailField(blank=True) + + class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel): + email_field = models.EmailField(max_length=150, blank=True) + + def test_default_model_value(self): + class EmailFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.EmailFieldModel + + serializer = EmailFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75) + + def test_given_model_value(self): + class EmailFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.EmailFieldWithGivenMaxLengthModel + + serializer = EmailFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150) + + def test_given_serializer_value(self): + class EmailFieldSerializer(serializers.ModelSerializer): + email_field = serializers.EmailField(source='email_field', max_length=20, required=False) + + class Meta: + model = self.EmailFieldModel + + serializer = EmailFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20) + + +class SlugFieldTests(TestCase): + """ + Tests for SlugField attribute values + """ + + class SlugFieldModel(RESTFrameworkModel): + slug_field = models.SlugField(blank=True) + + class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel): + slug_field = models.SlugField(max_length=84, blank=True) + + def test_default_model_value(self): + class SlugFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.SlugFieldModel + + serializer = SlugFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50) + + def test_given_model_value(self): + class SlugFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.SlugFieldWithGivenMaxLengthModel + + serializer = SlugFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84) + + def test_given_serializer_value(self): + class SlugFieldSerializer(serializers.ModelSerializer): + slug_field = serializers.SlugField(source='slug_field', + max_length=20, required=False) + + class Meta: + model = self.SlugFieldModel + + serializer = SlugFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['slug_field'], + 'max_length'), 20) + + def test_invalid_slug(self): + """ + Make sure an invalid slug raises ValidationError + """ + class SlugFieldSerializer(serializers.ModelSerializer): + slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True) + + class Meta: + model = self.SlugFieldModel + + s = SlugFieldSerializer(data={'slug_field': 'a b'}) + + self.assertEqual(s.is_valid(), False) + self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]}) + + +class URLFieldTests(TestCase): + """ + Tests for URLField attribute values + """ + + class URLFieldModel(RESTFrameworkModel): + url_field = models.URLField(blank=True) + + class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel): + url_field = models.URLField(max_length=128, blank=True) + + def test_default_model_value(self): + class URLFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.URLFieldModel + + serializer = URLFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 200) + + def test_given_model_value(self): + class URLFieldSerializer(serializers.ModelSerializer): + class Meta: + model = self.URLFieldWithGivenMaxLengthModel + + serializer = URLFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 128) + + def test_given_serializer_value(self): + class URLFieldSerializer(serializers.ModelSerializer): + url_field = serializers.URLField(source='url_field', + max_length=20, required=False) + + class Meta: + model = self.URLFieldWithGivenMaxLengthModel + + serializer = URLFieldSerializer(data={}) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(getattr(serializer.fields['url_field'], + 'max_length'), 20) + + +class FieldMetadata(TestCase): + def setUp(self): + self.required_field = serializers.Field() + self.required_field.label = uuid4().hex + self.required_field.required = True + + self.optional_field = serializers.Field() + self.optional_field.label = uuid4().hex + self.optional_field.required = False + + def test_required(self): + self.assertEqual(self.required_field.metadata()['required'], True) + + def test_optional(self): + self.assertEqual(self.optional_field.metadata()['required'], False) + + def test_label(self): + for field in (self.required_field, self.optional_field): + self.assertEqual(field.metadata()['label'], field.label) + + +class FieldCallableDefault(TestCase): + def setUp(self): + self.simple_callable = lambda: 'foo bar' + + def test_default_can_be_simple_callable(self): + """ + Ensure that the 'default' argument can also be a simple callable. + """ + field = serializers.WritableField(default=self.simple_callable) + into = {} + field.field_from_native({}, {}, 'field', into) + self.assertEqual(into, {'field': 'foo bar'}) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/test_files.py index 487046ac..487046ac 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/test_files.py diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/test_filters.py index 023bd016..aaed6247 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/test_filters.py @@ -1,23 +1,30 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url -from rest_framework.tests.models import FilterableItem, BasicModel +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + if django_filters: # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # These class are used to test a filter class. class SeveralFieldsFilter(django_filters.FilterSet): @@ -32,7 +39,7 @@ if django_filters: class FilterClassRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # These classes are used to test a misconfigured filter class. class MisconfiguredFilter(django_filters.FilterSet): @@ -45,12 +52,12 @@ if django_filters: class IncorrectlyConfiguredRootView(generics.ListCreateAPIView): model = FilterableItem filter_class = MisconfiguredFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) class FilterClassDetailView(generics.RetrieveAPIView): model = FilterableItem filter_class = SeveralFieldsFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) # Regression test for #814 class FilterableItemSerializer(serializers.ModelSerializer): @@ -61,11 +68,21 @@ if django_filters: queryset = FilterableItem.objects.all() serializer_class = FilterableItemSerializer filter_fields = ['decimal', 'date'] - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) + + class GetQuerysetView(generics.ListCreateAPIView): + serializer_class = FilterableItemSerializer + filter_class = SeveralFieldsFilter + filter_backends = (filters.DjangoFilterBackend,) + + def get_queryset(self): + return FilterableItem.objects.all() urlpatterns = patterns('', url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'), + url(r'^get-queryset/$', GetQuerysetView.as_view(), + name='get-queryset-view'), ) @@ -141,6 +158,17 @@ class IntegrationTestFiltering(CommonFilteringTestCase): self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_get_queryset_only(self): + """ + Regression test for #834. + """ + view = GetQuerysetView.as_view() + request = factory.get('/get-queryset/') + view(request).render() + # Used to raise "issubclass() arg 2 must be a class or tuple of classes" + # here when neither `model' nor `queryset' was specified. + + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """ GET requests to filtered ListCreateAPIView that have a filter_class set @@ -215,7 +243,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): """ Integration tests for filtered detail views. """ - urls = 'rest_framework.tests.filterset' + urls = 'rest_framework.tests.test_filters' def _get_url(self, item): return reverse('detail-view', kwargs=dict(pk=item.pk)) @@ -256,3 +284,191 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase): response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, valid_item_data) + + +class SearchFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class SearchFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # z abc + # zz bcd + # zzz cde + # ... + for idx in range(10): + title = 'z' * (idx + 1) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + SearchFilterModel(title=title, text=text).save() + + def test_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'z', 'text': 'abc'}, + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + def test_exact_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('=title', 'text') + + view = SearchListView.as_view() + request = factory.get('?search=zzz') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'zzz', 'text': 'cde'} + ] + ) + + def test_startswith_search(self): + class SearchListView(generics.ListAPIView): + model = SearchFilterModel + filter_backends = (filters.SearchFilter,) + search_fields = ('title', '^text') + + view = SearchListView.as_view() + request = factory.get('?search=b') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 2, 'title': 'zz', 'text': 'bcd'} + ] + ) + + +class OrdringFilterModel(models.Model): + title = models.CharField(max_length=20) + text = models.CharField(max_length=100) + + +class OrderingFilterTests(TestCase): + def setUp(self): + # Sequence of title/text is: + # + # zyx abc + # yxw bcd + # xwv cde + for idx in range(3): + title = ( + chr(ord('z') - idx) + + chr(ord('y') - idx) + + chr(ord('x') - idx) + ) + text = ( + chr(idx + ord('a')) + + chr(idx + ord('b')) + + chr(idx + ord('c')) + ) + OrdringFilterModel(title=title, text=text).save() + + def test_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + + def test_reverse_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=-text') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_incorrectfield_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('?ordering=foobar') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) + + def test_default_ordering_using_string(self): + class OrderingListView(generics.ListAPIView): + model = OrdringFilterModel + filter_backends = (filters.OrderingFilter,) + ordering = 'title' + + view = OrderingListView.as_view() + request = factory.get('') + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + ] + ) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/test_genericrelations.py index c38bfb9f..c38bfb9f 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/test_genericrelations.py diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/test_generics.py index eca50d82..37734195 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/test_generics.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase -from rest_framework import generics, serializers, status +from rest_framework import generics, renderers, serializers, status from rest_framework.tests.utils import RequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six @@ -39,6 +39,7 @@ class SlugBasedInstanceView(InstanceView): """ model = SlugBasedModel serializer_class = SlugSerializer + lookup_field = 'slug' class TestRootView(TestCase): @@ -120,7 +121,25 @@ class TestRootView(TestCase): 'text/html' ], 'name': 'Root', - 'description': 'Example description for OPTIONS.' + 'description': 'Example description for OPTIONS.', + 'actions': { + 'POST': { + 'text': { + 'max_length': 100, + 'read_only': False, + 'required': True, + 'type': 'string', + "label": "Text comes here", + "help_text": "Text description." + }, + 'id': { + 'read_only': True, + 'required': False, + 'type': 'integer', + 'label': 'ID', + }, + } + } } self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, expected) @@ -223,9 +242,9 @@ class TestInstanceView(TestCase): """ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata """ - request = factory.options('/') - with self.assertNumQueries(0): - response = self.view(request).render() + request = factory.options('/1') + with self.assertNumQueries(1): + response = self.view(request, pk=1).render() expected = { 'parses': [ 'application/json', @@ -237,11 +256,39 @@ class TestInstanceView(TestCase): 'text/html' ], 'name': 'Instance', - 'description': 'Example description for OPTIONS.' + 'description': 'Example description for OPTIONS.', + 'actions': { + 'PUT': { + 'text': { + 'max_length': 100, + 'read_only': False, + 'required': True, + 'type': 'string', + 'label': 'Text comes here', + 'help_text': 'Text description.' + }, + 'id': { + 'read_only': True, + 'required': False, + 'type': 'integer', + 'label': 'ID', + }, + } + } } self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, expected) + def test_get_instance_view_incorrect_arg(self): + """ + GET requests with an incorrect pk type, should raise 404, not 500. + Regression test for #890. + """ + request = factory.get('/a') + with self.assertNumQueries(0): + response = self.view(request, pk='a').render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + def test_put_cannot_set_id(self): """ PUT requests to create a new object should not be able to set the id. @@ -434,22 +481,14 @@ class TestFilterBackendAppliedToViews(TestCase): {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] - self.root_view = RootView.as_view() - self.instance_view = InstanceView.as_view() - self.original_root_backend = getattr(RootView, 'filter_backend') - self.original_instance_backend = getattr(InstanceView, 'filter_backend') - - def tearDown(self): - setattr(RootView, 'filter_backend', self.original_root_backend) - setattr(InstanceView, 'filter_backend', self.original_instance_backend) def test_get_root_view_filters_by_name_with_filter_backend(self): """ GET requests to ListCreateAPIView should return filtered list. """ - setattr(RootView, 'filter_backend', InclusiveFilterBackend) + root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) request = factory.get('/') - response = self.root_view(request).render() + response = root_view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 1) self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) @@ -458,9 +497,9 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to ListCreateAPIView should return empty list when all models are filtered out. """ - setattr(RootView, 'filter_backend', ExclusiveFilterBackend) + root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) request = factory.get('/') - response = self.root_view(request).render() + response = root_view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, []) @@ -468,9 +507,9 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. """ - setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend) + instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) request = factory.get('/1') - response = self.instance_view(request, pk=1).render() + response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.data, {'detail': 'Not found'}) @@ -478,8 +517,40 @@ class TestFilterBackendAppliedToViews(TestCase): """ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded """ - setattr(InstanceView, 'filter_backend', InclusiveFilterBackend) + instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) request = factory.get('/1') - response = self.instance_view(request, pk=1).render() + response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) + + +class TwoFieldModel(models.Model): + field_a = models.CharField(max_length=100) + field_b = models.CharField(max_length=100) + + +class DynamicSerializerView(generics.ListCreateAPIView): + model = TwoFieldModel + renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + + def get_serializer_class(self): + if self.request.method == 'POST': + class DynamicSerializer(serializers.ModelSerializer): + class Meta: + model = TwoFieldModel + fields = ('field_b',) + return DynamicSerializer + return super(DynamicSerializerView, self).get_serializer_class() + + +class TestFilterBackendAppliedToViews(TestCase): + + def test_dynamic_serializer_form_in_browsable_api(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + view = DynamicSerializerView.as_view() + request = factory.get('/') + response = view(request).render() + self.assertContains(response, 'field_b') + self.assertNotContains(response, 'field_a') diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py index 8f2e2b5a..8957a43c 100644 --- a/rest_framework/tests/htmlrenderer.py +++ b/rest_framework/tests/test_htmlrenderer.py @@ -42,7 +42,7 @@ urlpatterns = patterns('', class TemplateHTMLRendererTests(TestCase): - urls = 'rest_framework.tests.htmlrenderer' + urls = 'rest_framework.tests.test_htmlrenderer' def setUp(self): """ @@ -66,23 +66,23 @@ class TemplateHTMLRendererTests(TestCase): def test_simple_html_view(self): response = self.client.get('/') self.assertContains(response, "example: foobar") - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') def test_not_found_html_view(self): response = self.client.get('/not_found') self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.content, six.b("404 Not Found")) - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') def test_permission_denied_html_view(self): response = self.client.get('/permission_denied') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.content, six.b("403 Forbidden")) - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') class TemplateHTMLRendererExceptionTests(TestCase): - urls = 'rest_framework.tests.htmlrenderer' + urls = 'rest_framework.tests.test_htmlrenderer' def setUp(self): """ @@ -109,10 +109,10 @@ class TemplateHTMLRendererExceptionTests(TestCase): response = self.client.get('/not_found') self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.content, six.b("404: Not found")) - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') def test_permission_denied_html_view_with_template(self): response = self.client.get('/permission_denied') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.content, six.b("403: Permission denied")) - self.assertEqual(response['Content-Type'], 'text/html') + self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index 9a61f299..1894ddb2 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer): return Photo(**attrs) +class AlbumSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title') + + class Meta: + model = Album + fields = ('title', 'url') + + class BasicList(generics.ListCreateAPIView): model = BasicModel model_serializer_class = serializers.HyperlinkedModelSerializer @@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView): class AlbumDetail(generics.RetrieveAPIView): model = Album + serializer_class = AlbumSerializer + lookup_field = 'title' class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView): @@ -96,7 +106,7 @@ urlpatterns = patterns('', class TestBasicHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -133,7 +143,7 @@ class TestBasicHyperlinkedView(TestCase): class TestManyToManyHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -180,8 +190,38 @@ class TestManyToManyHyperlinkedView(TestCase): self.assertEqual(response.data, self.data[0]) +class TestHyperlinkedIdentityFieldLookup(TestCase): + urls = 'rest_framework.tests.test_hyperlinkedserializers' + + def setUp(self): + """ + Create 3 Album instances. + """ + titles = ['foo', 'bar', 'baz'] + for title in titles: + album = Album(title=title) + album.save() + self.detail_view = AlbumDetail.as_view() + self.data = { + 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'}, + 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'}, + 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'} + } + + def test_lookup_field(self): + """ + GET requests to AlbumDetail view should return serialized Albums + with a url field keyed by `title`. + """ + for album in Album.objects.all(): + request = factory.get('/albums/{0}/'.format(album.title)) + response = self.detail_view(request, title=album.title) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, self.data[album.title]) + + class TestCreateWithForeignKeys(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -206,7 +246,7 @@ class TestCreateWithForeignKeys(TestCase): class TestCreateWithForeignKeysAndCustomSlug(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ @@ -231,7 +271,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): class TestOptionalRelationHyperlinkedView(TestCase): - urls = 'rest_framework.tests.hyperlinkedserializers' + urls = 'rest_framework.tests.test_hyperlinkedserializers' def setUp(self): """ diff --git a/rest_framework/tests/multitable_inheritance.py b/rest_framework/tests/test_multitable_inheritance.py index 00c15327..00c15327 100644 --- a/rest_framework/tests/multitable_inheritance.py +++ b/rest_framework/tests/test_multitable_inheritance.py diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/test_negotiation.py index 43721b84..7f84827f 100644 --- a/rest_framework/tests/negotiation.py +++ b/rest_framework/tests/test_negotiation.py @@ -3,19 +3,24 @@ from django.test import TestCase from django.test.client import RequestFactory from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.request import Request +from rest_framework.renderers import BaseRenderer factory = RequestFactory() -class MockJSONRenderer(object): +class MockJSONRenderer(BaseRenderer): media_type = 'application/json' -class MockHTMLRenderer(object): +class MockHTMLRenderer(BaseRenderer): media_type = 'text/html' +class NoCharsetSpecifiedRenderer(BaseRenderer): + media_type = 'my/media' + + class TestAcceptedMediaType(TestCase): def setUp(self): self.renderers = [MockJSONRenderer(), MockHTMLRenderer()] diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/test_pagination.py index 6b8ef02f..e538a78e 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -1,18 +1,24 @@ from __future__ import unicode_literals import datetime from decimal import Decimal -import django +from django.db import models from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters -from rest_framework.tests.models import BasicModel, FilterableItem +from rest_framework.tests.models import BasicModel factory = RequestFactory() +class FilterableItem(models.Model): + text = models.CharField(max_length=100) + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS. @@ -124,7 +130,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): model = FilterableItem paginate_by = 10 filter_class = DecimalFilter - filter_backend = filters.DjangoFilterBackend + filter_backends = (filters.DjangoFilterBackend,) view = FilterFieldsRootView.as_view() @@ -171,7 +177,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): class BasicFilterFieldsRootView(generics.ListCreateAPIView): model = FilterableItem paginate_by = 10 - filter_backend = DecimalFilterBackend + filter_backends = (DecimalFilterBackend,) view = BasicFilterFieldsRootView.as_view() diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/test_parsers.py index 7699e10c..7699e10c 100644 --- a/rest_framework/tests/parsers.py +++ b/rest_framework/tests/test_parsers.py diff --git a/rest_framework/tests/permissions.py b/rest_framework/tests/test_permissions.py index b3993be5..6caaf65b 100644 --- a/rest_framework/tests/permissions.py +++ b/rest_framework/tests/test_permissions.py @@ -108,6 +108,48 @@ class ModelPermissionsIntegrationTests(TestCase): response = instance_view(request, pk='2') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + def test_options_permitted(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['POST']) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.permitted_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + + def test_options_disallowed(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.disallowed_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + def test_options_updateonly(self): + request = factory.options('/', content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = root_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotIn('actions', response.data) + + request = factory.options('/1', content_type='application/json', + HTTP_AUTHORIZATION=self.updateonly_credentials) + response = instance_view(request, pk='1') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('actions', response.data) + self.assertEqual(list(response.data['actions'].keys()), ['PUT']) + class OwnerModel(models.Model): text = models.CharField(max_length=100) diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py new file mode 100644 index 00000000..d19219c9 --- /dev/null +++ b/rest_framework/tests/test_relations.py @@ -0,0 +1,100 @@ +""" +General tests for relational fields. +""" +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import BlogPost + + +class NullModel(models.Model): + pass + + +class FieldTests(TestCase): + def test_pk_related_field_with_empty_string(self): + """ + Regression test for #446 + + https://github.com/tomchristie/django-rest-framework/issues/446 + """ + field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_hyperlinked_related_field_with_empty_string(self): + field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + def test_slug_related_field_with_empty_string(self): + field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') + self.assertRaises(serializers.ValidationError, field.from_native, '') + self.assertRaises(serializers.ValidationError, field.from_native, []) + + +class TestManyRelatedMixin(TestCase): + def test_missing_many_to_many_related_field(self): + ''' + Regression test for #632 + + https://github.com/tomchristie/django-rest-framework/pull/632 + ''' + field = serializers.RelatedField(many=True, read_only=False) + + into = {} + field.field_from_native({}, None, 'field_name', into) + self.assertEqual(into['field_name'], []) + + +# Regression tests for #694 (`source` attribute on related fields) + +class RelatedFieldSourceTests(TestCase): + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='get_blogposts_manager') + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='get_blogposts_queryset') + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.RelatedField(many=True, source='a.b.c') + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['BlogPost object']) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py index b1eed9a7..2ca7f4f2 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/test_relations_hyperlink.py @@ -4,6 +4,7 @@ from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url from rest_framework.tests.models import ( + BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) @@ -16,6 +17,7 @@ def dummy_view(request, pk): pass urlpatterns = patterns('', + url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), @@ -69,7 +71,7 @@ class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): # TODO: Add test that .data cannot be accessed prior to .is_valid class HyperlinkedManyToManyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): for idx in range(1, 4): @@ -177,7 +179,7 @@ class HyperlinkedManyToManyTests(TestCase): class HyperlinkedForeignKeyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = ForeignKeyTarget(name='target-1') @@ -305,7 +307,7 @@ class HyperlinkedForeignKeyTests(TestCase): class HyperlinkedNullableForeignKeyTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = ForeignKeyTarget(name='target-1') @@ -433,7 +435,7 @@ class HyperlinkedNullableForeignKeyTests(TestCase): class HyperlinkedNullableOneToOneTests(TestCase): - urls = 'rest_framework.tests.relations_hyperlink' + urls = 'rest_framework.tests.test_relations_hyperlink' def setUp(self): target = OneToOneTarget(name='target-1') @@ -451,3 +453,72 @@ class HyperlinkedNullableOneToOneTests(TestCase): {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, ] self.assertEqual(serializer.data, expected) + + +# Regression tests for #694 (`source` attribute on related fields) + +class HyperlinkedRelatedFieldSourceTests(TestCase): + urls = 'rest_framework.tests.test_relations_hyperlink' + + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='get_blogposts_manager', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='get_blogposts_queryset', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.HyperlinkedRelatedField( + many=True, + source='a.b.c', + view_name='dummy-url', + ) + field.context = {'request': request} + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, ['http://testserver/dummyurl/1/']) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/test_relations_nested.py index 8325580f..8325580f 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/test_relations_pk.py index 5ce8b567..e2a1b815 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/test_relations_pk.py @@ -1,7 +1,11 @@ from __future__ import unicode_literals +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +from rest_framework.tests.models import ( + BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, + NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource, +) from rest_framework.compat import six @@ -124,6 +128,7 @@ class PKManyToManyTests(TestCase): # Ensure source 4 is added, and everything else is as expected queryset = ManyToManySource.objects.all() serializer = ManyToManySourceSerializer(queryset, many=True) + self.assertFalse(serializer.fields['targets'].read_only) expected = [ {'id': 1, 'name': 'source-1', 'targets': [1]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]}, @@ -135,6 +140,7 @@ class PKManyToManyTests(TestCase): def test_reverse_many_to_many_create(self): data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]} serializer = ManyToManyTargetSerializer(data=data) + self.assertFalse(serializer.fields['sources'].read_only) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEqual(serializer.data, data) @@ -421,3 +427,116 @@ class PKNullableOneToOneTests(TestCase): {'id': 2, 'name': 'target-2', 'nullable_source': 1}, ] self.assertEqual(serializer.data, expected) + + +# The below models and tests ensure that serializer fields corresponding +# to a ManyToManyField field with a user-specified ``through`` model are +# set to read only + + +class ManyToManyThroughTarget(models.Model): + name = models.CharField(max_length=100) + + +class ManyToManyThrough(models.Model): + source = models.ForeignKey('ManyToManyThroughSource') + target = models.ForeignKey(ManyToManyThroughTarget) + + +class ManyToManyThroughSource(models.Model): + name = models.CharField(max_length=100) + targets = models.ManyToManyField(ManyToManyThroughTarget, + related_name='sources', + through='ManyToManyThrough') + + +class ManyToManyThroughTargetSerializer(serializers.ModelSerializer): + class Meta: + model = ManyToManyThroughTarget + fields = ('id', 'name', 'sources') + + +class ManyToManyThroughSourceSerializer(serializers.ModelSerializer): + class Meta: + model = ManyToManyThroughSource + fields = ('id', 'name', 'targets') + + +class PKManyToManyThroughTests(TestCase): + def setUp(self): + self.source = ManyToManyThroughSource.objects.create( + name='through-source-1') + self.target = ManyToManyThroughTarget.objects.create( + name='through-target-1') + + def test_many_to_many_create(self): + data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]} + serializer = ManyToManyThroughSourceSerializer(data=data) + self.assertTrue(serializer.fields['targets'].read_only) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(obj.name, 'source-2') + self.assertEqual(obj.targets.count(), 0) + + def test_many_to_many_reverse_create(self): + data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]} + serializer = ManyToManyThroughTargetSerializer(data=data) + self.assertTrue(serializer.fields['sources'].read_only) + self.assertTrue(serializer.is_valid()) + serializer.save() + obj = serializer.save() + self.assertEqual(obj.name, 'target-2') + self.assertEqual(obj.sources.count(), 0) + + +# Regression tests for #694 (`source` attribute on related fields) + + +class PrimaryKeyRelatedFieldSourceTests(TestCase): + def test_related_manager_source(self): + """ + Relational fields should be able to use manager-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager') + + class ClassWithManagerMethod(object): + def get_blogposts_manager(self): + return BlogPost.objects + + obj = ClassWithManagerMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) + + def test_related_queryset_source(self): + """ + Relational fields should be able to use queryset-returning methods as their source. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset') + + class ClassWithQuerysetMethod(object): + def get_blogposts_queryset(self): + return BlogPost.objects.all() + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) + + def test_dotted_source(self): + """ + Source argument should support dotted.source notation. + """ + BlogPost.objects.create(title='blah') + field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c') + + class ClassWithQuerysetMethod(object): + a = { + 'b': { + 'c': BlogPost.objects.all() + } + } + + obj = ClassWithQuerysetMethod() + value = field.field_to_native(obj, 'field_name') + self.assertEqual(value, [1]) diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/test_relations_slug.py index 435c821c..435c821c 100644 --- a/rest_framework/tests/relations_slug.py +++ b/rest_framework/tests/test_relations_slug.py diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/test_renderers.py index 40bac9cb..95b59741 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -1,14 +1,18 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + from decimal import Decimal from django.core.cache import cache from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest +from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions from rest_framework.compat import yaml, etree, patterns, url, include from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ - XMLRenderer, JSONPRenderer, BrowsableAPIRenderer + XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings from rest_framework.compat import StringIO @@ -26,7 +30,7 @@ RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii') expected_results = [ - ((elem for elem in [1, 2, 3]), JSONRenderer, '[1, 2, 3]') # Generator + ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator ] @@ -129,12 +133,12 @@ class RendererEndToEndTests(TestCase): End-to-end testing of renderers using an RendererMixin on a generic view. """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" resp = self.client.get('/') - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -142,13 +146,13 @@ class RendererEndToEndTests(TestCase): """No response must be included in HEAD requests.""" resp = self.client.head('/') self.assertEqual(resp.status_code, DUMMYSTATUS) - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, six.b('')) def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -156,7 +160,7 @@ class RendererEndToEndTests(TestCase): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -164,7 +168,7 @@ class RendererEndToEndTests(TestCase): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -175,7 +179,7 @@ class RendererEndToEndTests(TestCase): RendererB.media_type ) resp = self.client.get('/' + param) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -192,7 +196,7 @@ class RendererEndToEndTests(TestCase): RendererB.format ) resp = self.client.get('/' + param) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -200,7 +204,7 @@ class RendererEndToEndTests(TestCase): """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/something.formatb') - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -213,7 +217,7 @@ class RendererEndToEndTests(TestCase): ) resp = self.client.get('/' + param, HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -235,6 +239,13 @@ class JSONRendererTests(TestCase): Tests specific to the JSON Renderer """ + def test_render_lazy_strings(self): + """ + JSONRenderer should deal with lazy translated strings. + """ + ret = JSONRenderer().render(_('test')) + self.assertEqual(ret, b'"test"') + def test_without_content_type_args(self): """ Test basic JSON rendering. @@ -243,7 +254,7 @@ class JSONRendererTests(TestCase): renderer = JSONRenderer() content = renderer.render(obj, 'application/json') # Fix failing test case which depends on version of JSON library. - self.assertEqual(content, _flat_repr) + self.assertEqual(content.decode('utf-8'), _flat_repr) def test_with_content_type_args(self): """ @@ -252,7 +263,24 @@ class JSONRendererTests(TestCase): obj = {'foo': ['bar', 'baz']} renderer = JSONRenderer() content = renderer.render(obj, 'application/json; indent=2') - self.assertEqual(strip_trailing_whitespace(content), _indented_repr) + self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) + + def test_check_ascii(self): + obj = {'countries': ['United Kingdom', 'France', 'España']} + renderer = JSONRenderer() + content = renderer.render(obj, 'application/json') + self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8')) + + +class UnicodeJSONRendererTests(TestCase): + """ + Tests specific for the Unicode JSON Renderer + """ + def test_proper_encoding(self): + obj = {'countries': ['United Kingdom', 'France', 'España']} + renderer = UnicodeJSONRenderer() + content = renderer.render(obj, 'application/json') + self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8')) class JSONPRendererTests(TestCase): @@ -260,7 +288,7 @@ class JSONPRendererTests(TestCase): Tests specific to the JSONP Renderer """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' def test_without_callback_with_json_renderer(self): """ @@ -269,7 +297,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/jsonrenderer', HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('callback(%s);' % _flat_repr).encode('ascii')) @@ -280,7 +308,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/nojsonrenderer', HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('callback(%s);' % _flat_repr).encode('ascii')) @@ -292,7 +320,7 @@ class JSONPRendererTests(TestCase): resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, HTTP_ACCEPT='application/javascript') self.assertEqual(resp.status_code, status.HTTP_200_OK) - self.assertEqual(resp['Content-Type'], 'application/javascript') + self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp.content, ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) @@ -433,7 +461,7 @@ class CacheRenderTest(TestCase): Tests specific to caching responses """ - urls = 'rest_framework.tests.renderers' + urls = 'rest_framework.tests.test_renderers' cache_key = 'just_a_cache_key' diff --git a/rest_framework/tests/request.py b/rest_framework/tests/test_request.py index 97e5af20..a5c5e84c 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/test_request.py @@ -254,7 +254,7 @@ urlpatterns = patterns('', class TestContentParsingWithAuthentication(TestCase): - urls = 'rest_framework.tests.request' + urls = 'rest_framework.tests.test_request' def setUp(self): self.csrf_client = Client(enforce_csrf_checks=True) diff --git a/rest_framework/tests/response.py b/rest_framework/tests/test_response.py index aecf83f4..eea3c641 100644 --- a/rest_framework/tests/response.py +++ b/rest_framework/tests/test_response.py @@ -1,14 +1,18 @@ from __future__ import unicode_literals from django.test import TestCase +from rest_framework.tests.models import BasicModel, BasicModelSerializer from rest_framework.compat import patterns, url, include from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework import generics +from rest_framework import routers from rest_framework import status from rest_framework.renderers import ( BaseRenderer, JSONRenderer, BrowsableAPIRenderer ) +from rest_framework import viewsets from rest_framework.settings import api_settings from rest_framework.compat import six @@ -21,6 +25,9 @@ class MockJsonRenderer(BaseRenderer): media_type = 'application/json' +class MockTextMediaRenderer(BaseRenderer): + media_type = 'text/html' + DUMMYSTATUS = status.HTTP_200_OK DUMMYCONTENT = 'dummycontent' @@ -44,13 +51,26 @@ class RendererB(BaseRenderer): return RENDERER_B_SERIALIZER(data) +class RendererC(RendererB): + media_type = 'mock/rendererc' + format = 'formatc' + charset = "rendererc" + + class MockView(APIView): - renderer_classes = (RendererA, RendererB) + renderer_classes = (RendererA, RendererB, RendererC) def get(self, request, **kwargs): return Response(DUMMYCONTENT, status=DUMMYSTATUS) +class MockViewSettingContentType(APIView): + renderer_classes = (RendererA, RendererB, RendererC) + + def get(self, request, **kwargs): + return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview') + + class HTMLView(APIView): renderer_classes = (BrowsableAPIRenderer, ) @@ -65,11 +85,29 @@ class HTMLView1(APIView): return Response('text') +class HTMLNewModelViewSet(viewsets.ModelViewSet): + model = BasicModel + + +class HTMLNewModelView(generics.ListCreateAPIView): + renderer_classes = (BrowsableAPIRenderer,) + permission_classes = [] + serializer_class = BasicModelSerializer + model = BasicModel + + +new_model_viewset_router = routers.DefaultRouter() +new_model_viewset_router.register(r'', HTMLNewModelViewSet) + + urlpatterns = patterns('', - url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), - url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), + url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])), + url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), + url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^html$', HTMLView.as_view()), url(r'^html1$', HTMLView1.as_view()), + url(r'^html_new_model$', HTMLNewModelView.as_view()), + url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)), url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) ) @@ -80,12 +118,12 @@ class RendererIntegrationTests(TestCase): End-to-end testing of renderers using an ResponseMixin on a generic view. """ - urls = 'rest_framework.tests.response' + urls = 'rest_framework.tests.test_response' def test_default_renderer_serializes_content(self): """If the Accept header is not set the default renderer should serialize the response.""" resp = self.client.get('/') - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -93,13 +131,13 @@ class RendererIntegrationTests(TestCase): """No response must be included in HEAD requests.""" resp = self.client.head('/') self.assertEqual(resp.status_code, DUMMYSTATUS) - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, six.b('')) def test_default_renderer_serializes_content_on_accept_any(self): """If the Accept header is set to */* the default renderer should serialize the response.""" resp = self.client.get('/', HTTP_ACCEPT='*/*') - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -107,7 +145,7 @@ class RendererIntegrationTests(TestCase): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for the default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) - self.assertEqual(resp['Content-Type'], RendererA.media_type) + self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -115,7 +153,7 @@ class RendererIntegrationTests(TestCase): """If the Accept header is set the specified renderer should serialize the response. (In this case we check that works for a non-default renderer)""" resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -126,7 +164,7 @@ class RendererIntegrationTests(TestCase): RendererB.media_type ) resp = self.client.get('/' + param) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -134,7 +172,7 @@ class RendererIntegrationTests(TestCase): """If a 'format' query is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/?format=%s' % RendererB.format) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -142,7 +180,7 @@ class RendererIntegrationTests(TestCase): """If a 'format' keyword arg is specified, the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/something.formatb') - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -151,7 +189,7 @@ class RendererIntegrationTests(TestCase): the renderer with the matching format attribute should serialize the response.""" resp = self.client.get('/?format=%s' % RendererB.format, HTTP_ACCEPT=RendererB.media_type) - self.assertEqual(resp['Content-Type'], RendererB.media_type) + self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.status_code, DUMMYSTATUS) @@ -160,7 +198,7 @@ class Issue122Tests(TestCase): """ Tests that covers #122. """ - urls = 'rest_framework.tests.response' + urls = 'rest_framework.tests.test_response' def test_only_html_renderer(self): """ @@ -173,3 +211,68 @@ class Issue122Tests(TestCase): Test if no infinite recursion occurs. """ self.client.get('/html1') + + +class Issue467Tests(TestCase): + """ + Tests for #467 + """ + + urls = 'rest_framework.tests.test_response' + + def test_form_has_label_and_help_text(self): + resp = self.client.get('/html_new_model') + self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') + self.assertContains(resp, 'Text comes here') + self.assertContains(resp, 'Text description.') + + +class Issue807Tests(TestCase): + """ + Covers #807 + """ + + urls = 'rest_framework.tests.test_response' + + def test_does_not_append_charset_by_default(self): + """ + Renderers don't include a charset unless set explicitly. + """ + headers = {"HTTP_ACCEPT": RendererA.media_type} + resp = self.client.get('/', **headers) + expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8') + self.assertEqual(expected, resp['Content-Type']) + + def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self): + """ + If renderer class has charset attribute declared, it gets appended + to Response's Content-Type + """ + headers = {"HTTP_ACCEPT": RendererC.media_type} + resp = self.client.get('/', **headers) + expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset) + self.assertEqual(expected, resp['Content-Type']) + + def test_content_type_set_explictly_on_response(self): + """ + The content type may be set explictly on the response. + """ + headers = {"HTTP_ACCEPT": RendererC.media_type} + resp = self.client.get('/setbyview', **headers) + self.assertEqual('setbyview', resp['Content-Type']) + + def test_viewset_label_help_text(self): + param = '?%s=%s' % ( + api_settings.URL_ACCEPT_OVERRIDE, + 'text/html' + ) + resp = self.client.get('/html_new_model_viewset/' + param) + self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') + self.assertContains(resp, 'Text comes here') + self.assertContains(resp, 'Text description.') + + def test_form_has_label_and_help_text(self): + resp = self.client.get('/html_new_model') + self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') + self.assertContains(resp, 'Text comes here') + self.assertContains(resp, 'Text description.') diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/test_reverse.py index cb8d8132..93ef5637 100644 --- a/rest_framework/tests/reverse.py +++ b/rest_framework/tests/test_reverse.py @@ -19,7 +19,7 @@ class ReverseTests(TestCase): """ Tests for fully qualified URLs when using `reverse`. """ - urls = 'rest_framework.tests.reverse' + urls = 'rest_framework.tests.test_reverse' def test_reversed_urls_are_fully_qualified(self): request = factory.get('/view') diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py new file mode 100644 index 00000000..a7534f70 --- /dev/null +++ b/rest_framework/tests/test_routers.py @@ -0,0 +1,150 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import serializers, viewsets +from rest_framework.compat import include, patterns, url +from rest_framework.decorators import link, action +from rest_framework.response import Response +from rest_framework.routers import SimpleRouter + +factory = RequestFactory() + +urlpatterns = patterns('',) + + +class BasicViewSet(viewsets.ViewSet): + def list(self, request, *args, **kwargs): + return Response({'method': 'list'}) + + @action() + def action1(self, request, *args, **kwargs): + return Response({'method': 'action1'}) + + @action() + def action2(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @action(methods=['post', 'delete']) + def action3(self, request, *args, **kwargs): + return Response({'method': 'action2'}) + + @link() + def link1(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @link() + def link2(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + + +class TestSimpleRouter(TestCase): + def setUp(self): + self.router = SimpleRouter() + + def test_link_and_action_decorator(self): + routes = self.router.get_routes(BasicViewSet) + decorator_routes = routes[2:] + # Make sure all these endpoints exist and none have been clobbered + for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']): + route = decorator_routes[i] + # check url listing + self.assertEqual(route.url, + '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) + # check method to function mapping + if endpoint == 'action3': + methods_map = ['post', 'delete'] + elif endpoint.startswith('action'): + methods_map = ['post'] + else: + methods_map = ['get'] + for method in methods_map: + self.assertEqual(route.mapping[method], endpoint) + + +class RouterTestModel(models.Model): + uuid = models.CharField(max_length=20) + text = models.CharField(max_length=200) + + +class TestCustomLookupFields(TestCase): + """ + Ensure that custom lookup fields are correctly routed. + """ + urls = 'rest_framework.tests.test_routers' + + def setUp(self): + class NoteSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RouterTestModel + lookup_field = 'uuid' + fields = ('url', 'uuid', 'text') + + class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() + serializer_class = NoteSerializer + lookup_field = 'uuid' + + RouterTestModel.objects.create(uuid='123', text='foo bar') + + self.router = SimpleRouter() + self.router.register(r'notes', NoteViewSet) + + from rest_framework.tests import test_routers + urls = getattr(test_routers, 'urlpatterns') + urls += patterns('', + url(r'^', include(self.router.urls)), + ) + + def test_custom_lookup_field_route(self): + detail_route = self.router.urls[-1] + detail_url_pattern = detail_route.regex.pattern + self.assertIn('<uuid>', detail_url_pattern) + + def test_retrieve_lookup_field_list_view(self): + response = self.client.get('/notes/') + self.assertEqual(response.data, + [{ + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + }] + ) + + def test_retrieve_lookup_field_detail_view(self): + response = self.client.get('/notes/123/') + self.assertEqual(response.data, + { + "url": "http://testserver/notes/123/", + "uuid": "123", "text": "foo bar" + } + ) + + +class TestTrailingSlash(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + + self.router = SimpleRouter() + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_urls_have_trailing_slash_by_default(self): + expected = ['^notes/$', '^notes/(?P<pk>[^/]+)/$'] + for idx in range(len(expected)): + self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + + +class TestTrailingSlash(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + + self.router = SimpleRouter(trailing_slash=False) + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_urls_can_have_trailing_slash_removed(self): + expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] + for idx in range(len(expected)): + self.assertEqual(expected[idx], self.urls[idx].regex.pattern) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/test_serializer.py index 84e1ee4e..8b87a084 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -1,10 +1,14 @@ from __future__ import unicode_literals -from django.utils.datastructures import MultiValueDict +from django.db import models +from django.db.models.fields import BLANK_CHOICE_DASH from django.test import TestCase -from rest_framework import serializers +from django.utils.datastructures import MultiValueDict +from django.utils.translation import ugettext_lazy as _ +from rest_framework import serializers, fields, relations from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, - ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) + ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel) +from rest_framework.tests.models import BasicModelSerializer import datetime import pickle @@ -43,6 +47,17 @@ class CommentSerializer(serializers.Serializer): return instance +class NamesSerializer(serializers.Serializer): + first = serializers.CharField() + last = serializers.CharField(required=False, default='') + initials = serializers.CharField(required=False, default='') + + +class PersonIdentifierSerializer(serializers.Serializer): + ssn = serializers.CharField() + names = NamesSerializer(source='names', required=False) + + class BookSerializer(serializers.ModelSerializer): isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'}) @@ -78,6 +93,29 @@ class PersonSerializer(serializers.ModelSerializer): read_only_fields = ('age',) +class NestedSerializer(serializers.Serializer): + info = serializers.Field() + + +class ModelSerializerWithNestedSerializer(serializers.ModelSerializer): + nested = NestedSerializer(source='*') + + class Meta: + model = Person + + +class PersonSerializerInvalidReadOnly(serializers.ModelSerializer): + """ + Testing for #652. + """ + info = serializers.Field(source='info') + + class Meta: + model = Person + fields = ('name', 'age', 'info') + read_only_fields = ('age', 'info') + + class AlbumsSerializer(serializers.ModelSerializer): class Meta: @@ -91,11 +129,6 @@ class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): fields = ['some_integer'] -class BrokenModelSerializer(serializers.ModelSerializer): - class Meta: - fields = ['some_field'] - - class BasicTests(TestCase): def setUp(self): self.comment = Comment( @@ -141,6 +174,42 @@ class BasicTests(TestCase): self.assertFalse(serializer.object is expected) self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!') + def test_create_nested(self): + """Test a serializer with nested data.""" + names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'} + data = {'ssn': '1234567890', 'names': names} + serializer = PersonIdentifierSerializer(data=data) + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is data) + self.assertEqual(serializer.data['names'], names) + + def test_create_partial_nested(self): + """Test a serializer with nested data which has missing fields.""" + names = {'first': 'John'} + data = {'ssn': '1234567890', 'names': names} + serializer = PersonIdentifierSerializer(data=data) + + expected_names = {'first': 'John', 'last': '', 'initials': ''} + data['names'] = expected_names + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is expected_names) + self.assertEqual(serializer.data['names'], expected_names) + + def test_null_nested(self): + """Test a serializer with a nonexistent nested field""" + data = {'ssn': '1234567890'} + serializer = PersonIdentifierSerializer(data=data) + + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + self.assertFalse(serializer.object is data) + expected = {'ssn': '1234567890', 'names': None} + self.assertEqual(serializer.data, expected) + def test_update(self): serializer = CommentSerializer(self.comment, data=self.data) expected = self.comment @@ -189,6 +258,12 @@ class BasicTests(TestCase): # Assert age is unchanged (35) self.assertEqual(instance.age, self.person_data['age']) + def test_invalid_read_only_fields(self): + """ + Regression test for #652. + """ + self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) + class DictStyleSerializer(serializers.Serializer): """ @@ -344,19 +419,34 @@ class ValidationTests(TestCase): Assert that a meaningful exception message is outputted when the model field is missing (e.g. when mistyping ``model``). """ + class BrokenModelSerializer(serializers.ModelSerializer): + class Meta: + fields = ['some_field'] + try: - serializer = BrokenModelSerializer() + BrokenModelSerializer() except AssertionError as e: self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option") except: self.fail('Wrong exception type thrown.') + def test_writable_star_source_on_nested_serializer(self): + """ + Assert that a nested serializer instantiated with source='*' correctly + expands the data into the outer serializer. + """ + serializer = ModelSerializerWithNestedSerializer(data={ + 'name': 'marko', + 'nested': {'info': 'hi'}}, + ) + self.assertEqual(serializer.is_valid(), True) + class CustomValidationTests(TestCase): class CommentSerializerWithFieldValidator(CommentSerializer): def validate_email(self, attrs, source): - value = attrs[source] + attrs[source] return attrs def validate_content(self, attrs, source): @@ -853,23 +943,6 @@ class RelatedTraversalTest(TestCase): self.assertEqual(serializer.data, expected) - def test_queryset_nested_traversal(self): - """ - Relational fields should be able to use methods as their source. - """ - BlogPost.objects.create(title='blah') - - class QuerysetMethodSerializer(serializers.Serializer): - blogposts = serializers.RelatedField(many=True, source='get_all_blogposts') - - class ClassWithQuerysetMethod(object): - def get_all_blogposts(self): - return BlogPost.objects - - obj = ClassWithQuerysetMethod() - serializer = QuerysetMethodSerializer(obj) - self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']}) - class SerializerMethodFieldTests(TestCase): def setUp(self): @@ -1000,6 +1073,130 @@ class SerializerPickleTests(TestCase): repr(pickle.loads(pickle.dumps(data, 0))) +# test for issue #725 +class SeveralChoicesModel(models.Model): + color = models.CharField( + max_length=10, + choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')], + blank=False + ) + drink = models.CharField( + max_length=10, + choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')], + blank=False, + default='beer' + ) + os = models.CharField( + max_length=10, + choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')], + blank=True + ) + music_genre = models.CharField( + max_length=10, + choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')], + blank=True, + default='metal' + ) + + +class SerializerChoiceFields(TestCase): + + def setUp(self): + super(SerializerChoiceFields, self).setUp() + + class SeveralChoicesSerializer(serializers.ModelSerializer): + class Meta: + model = SeveralChoicesModel + fields = ('color', 'drink', 'os', 'music_genre') + + self.several_choices_serializer = SeveralChoicesSerializer + + def test_choices_blank_false_not_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['color'].choices, + [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')] + ) + + def test_choices_blank_false_with_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['drink'].choices, + [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')] + ) + + def test_choices_blank_true_not_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['os'].choices, + BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')] + ) + + def test_choices_blank_true_with_default(self): + serializer = self.several_choices_serializer() + self.assertEqual( + serializer.fields['music_genre'].choices, + BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')] + ) + + +# Regression tests for #675 +class Ticket(models.Model): + assigned = models.ForeignKey( + Person, related_name='assigned_tickets') + reviewer = models.ForeignKey( + Person, blank=True, null=True, related_name='reviewed_tickets') + + +class SerializerRelatedChoicesTest(TestCase): + + def setUp(self): + super(SerializerRelatedChoicesTest, self).setUp() + + class RelatedChoicesSerializer(serializers.ModelSerializer): + class Meta: + model = Ticket + fields = ('assigned', 'reviewer') + + self.related_fields_serializer = RelatedChoicesSerializer + + def test_empty_queryset_required(self): + serializer = self.related_fields_serializer() + self.assertEqual(serializer.fields['assigned'].queryset.count(), 0) + self.assertEqual( + [x for x in serializer.fields['assigned'].widget.choices], + [] + ) + + def test_empty_queryset_not_required(self): + serializer = self.related_fields_serializer() + self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0) + self.assertEqual( + [x for x in serializer.fields['reviewer'].widget.choices], + [('', '---------')] + ) + + def test_with_some_persons_required(self): + Person.objects.create(name="Lionel Messi") + Person.objects.create(name="Xavi Hernandez") + serializer = self.related_fields_serializer() + self.assertEqual(serializer.fields['assigned'].queryset.count(), 2) + self.assertEqual( + [x for x in serializer.fields['assigned'].widget.choices], + [(1, 'Person object - 1'), (2, 'Person object - 2')] + ) + + def test_with_some_persons_not_required(self): + Person.objects.create(name="Lionel Messi") + Person.objects.create(name="Xavi Hernandez") + serializer = self.related_fields_serializer() + self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2) + self.assertEqual( + [x for x in serializer.fields['reviewer'].widget.choices], + [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')] + ) + + class DepthTest(TestCase): def test_implicit_nesting(self): @@ -1125,3 +1322,312 @@ class DeserializeListTestCase(TestCase): self.assertFalse(serializer.is_valid()) expected = [{}, {'email': ['This field is required.']}, {}] self.assertEqual(serializer.errors, expected) + + +# Test for issue 747 + +class LazyStringModel(object): + def __init__(self, lazystring): + self.lazystring = lazystring + + +class LazyStringSerializer(serializers.Serializer): + lazystring = serializers.Field() + + def restore_object(self, attrs, instance=None): + if instance is not None: + instance.lazystring = attrs.get('lazystring', instance.lazystring) + return instance + return LazyStringModel(**attrs) + + +class LazyStringsTestCase(TestCase): + def setUp(self): + self.model = LazyStringModel(lazystring=_('lazystring')) + + def test_lazy_strings_are_translated(self): + serializer = LazyStringSerializer(self.model) + self.assertEqual(type(serializer.data['lazystring']), + type('lazystring')) + + +# Test for issue #467 + +class FieldLabelTest(TestCase): + def setUp(self): + self.serializer_class = BasicModelSerializer + + def test_label_from_model(self): + """ + Validates that label and help_text are correctly copied from the model class. + """ + serializer = self.serializer_class() + text_field = serializer.fields['text'] + + self.assertEqual('Text comes here', text_field.label) + self.assertEqual('Text description.', text_field.help_text) + + def test_field_ctor(self): + """ + This is check that ctor supports both label and help_text. + """ + self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label) + self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text) + self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) + + +class AttributeMappingOnAutogeneratedFieldsTests(TestCase): + + def setUp(self): + class AMOAFModel(RESTFrameworkModel): + char_field = models.CharField(max_length=1024, blank=True) + comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) + decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) + email_field = models.EmailField(max_length=1024, blank=True) + file_field = models.FileField(max_length=1024, blank=True) + image_field = models.ImageField(max_length=1024, blank=True) + slug_field = models.SlugField(max_length=1024, blank=True) + url_field = models.URLField(max_length=1024, blank=True) + + class AMOAFSerializer(serializers.ModelSerializer): + class Meta: + model = AMOAFModel + + self.serializer_class = AMOAFSerializer + self.fields_attributes = { + 'char_field': [ + ('max_length', 1024), + ], + 'comma_separated_integer_field': [ + ('max_length', 1024), + ], + 'decimal_field': [ + ('max_digits', 64), + ('decimal_places', 32), + ], + 'email_field': [ + ('max_length', 1024), + ], + 'file_field': [ + ('max_length', 1024), + ], + 'image_field': [ + ('max_length', 1024), + ], + 'slug_field': [ + ('max_length', 1024), + ], + 'url_field': [ + ('max_length', 1024), + ], + } + + def field_test(self, field): + serializer = self.serializer_class(data={}) + self.assertEqual(serializer.is_valid(), True) + + for attribute in self.fields_attributes[field]: + self.assertEqual( + getattr(serializer.fields[field], attribute[0]), + attribute[1] + ) + + def test_char_field(self): + self.field_test('char_field') + + def test_comma_separated_integer_field(self): + self.field_test('comma_separated_integer_field') + + def test_decimal_field(self): + self.field_test('decimal_field') + + def test_email_field(self): + self.field_test('email_field') + + def test_file_field(self): + self.field_test('file_field') + + def test_image_field(self): + self.field_test('image_field') + + def test_slug_field(self): + self.field_test('slug_field') + + def test_url_field(self): + self.field_test('url_field') + + +class DefaultValuesOnAutogeneratedFieldsTests(TestCase): + + def setUp(self): + class DVOAFModel(RESTFrameworkModel): + positive_integer_field = models.PositiveIntegerField(blank=True) + positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) + email_field = models.EmailField(blank=True) + file_field = models.FileField(blank=True) + image_field = models.ImageField(blank=True) + slug_field = models.SlugField(blank=True) + url_field = models.URLField(blank=True) + + class DVOAFSerializer(serializers.ModelSerializer): + class Meta: + model = DVOAFModel + + self.serializer_class = DVOAFSerializer + self.fields_attributes = { + 'positive_integer_field': [ + ('min_value', 0), + ], + 'positive_small_integer_field': [ + ('min_value', 0), + ], + 'email_field': [ + ('max_length', 75), + ], + 'file_field': [ + ('max_length', 100), + ], + 'image_field': [ + ('max_length', 100), + ], + 'slug_field': [ + ('max_length', 50), + ], + 'url_field': [ + ('max_length', 200), + ], + } + + def field_test(self, field): + serializer = self.serializer_class(data={}) + self.assertEqual(serializer.is_valid(), True) + + for attribute in self.fields_attributes[field]: + self.assertEqual( + getattr(serializer.fields[field], attribute[0]), + attribute[1] + ) + + def test_positive_integer_field(self): + self.field_test('positive_integer_field') + + def test_positive_small_integer_field(self): + self.field_test('positive_small_integer_field') + + def test_email_field(self): + self.field_test('email_field') + + def test_file_field(self): + self.field_test('file_field') + + def test_image_field(self): + self.field_test('image_field') + + def test_slug_field(self): + self.field_test('slug_field') + + def test_url_field(self): + self.field_test('url_field') + + +class MetadataSerializer(serializers.Serializer): + field1 = serializers.CharField(3, required=True) + field2 = serializers.CharField(10, required=False) + + +class MetadataSerializerTestCase(TestCase): + def setUp(self): + self.serializer = MetadataSerializer() + + def test_serializer_metadata(self): + metadata = self.serializer.metadata() + expected = { + 'field1': { + 'required': True, + 'max_length': 3, + 'type': 'string', + 'read_only': False + }, + 'field2': { + 'required': False, + 'max_length': 10, + 'type': 'string', + 'read_only': False + } + } + self.assertEqual(expected, metadata) + + +### Regression test for #840 + +class SimpleModel(models.Model): + text = models.CharField(max_length=100) + + +class SimpleModelSerializer(serializers.ModelSerializer): + text = serializers.CharField() + other = serializers.CharField() + + class Meta: + model = SimpleModel + + def validate_other(self, attrs, source): + del attrs['other'] + return attrs + + +class FieldValidationRemovingAttr(TestCase): + def test_removing_non_model_field_in_validation(self): + """ + Removing an attr during field valiation should ensure that it is not + passed through when restoring the object. + + This allows additional non-model fields to be supported. + + Regression test for #840. + """ + serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'}) + self.assertTrue(serializer.is_valid()) + serializer.save() + self.assertEqual(serializer.object.text, 'foo') + + +### Regression test for #878 + +class SimpleTargetModel(models.Model): + text = models.CharField(max_length=100) + + +class SimplePKSourceModelSerializer(serializers.Serializer): + targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True) + text = serializers.CharField() + + +class SimpleSlugSourceModelSerializer(serializers.Serializer): + targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk') + text = serializers.CharField() + + +class SerializerSupportsManyRelationships(TestCase): + def setUp(self): + SimpleTargetModel.objects.create(text='foo') + SimpleTargetModel.objects.create(text='bar') + + def test_serializer_supports_pk_many_relationships(self): + """ + Regression test for #878. + + Note that pk behavior has a different code path to usual cases, + for performance reasons. + """ + serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) + + def test_serializer_supports_slug_many_relationships(self): + """ + Regression test for #878. + """ + serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/test_serializer_bulk_update.py index 8b0ded1a..8b0ded1a 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/test_serializer_bulk_update.py diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/test_serializer_nested.py index 71d0e24b..71d0e24b 100644 --- a/rest_framework/tests/serializer_nested.py +++ b/rest_framework/tests/test_serializer_nested.py diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/test_settings.py index 857375c2..857375c2 100644 --- a/rest_framework/tests/settings.py +++ b/rest_framework/tests/test_settings.py diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/test_throttling.py index 11cbd8eb..da400b2f 100644 --- a/rest_framework/tests/throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -36,7 +36,7 @@ class MockView_MinuteThrottling(APIView): class ThrottlingTests(TestCase): - urls = 'rest_framework.tests.throttling' + urls = 'rest_framework.tests.test_throttling' def setUp(self): """ diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/test_urlpatterns.py index 29ed4a96..29ed4a96 100644 --- a/rest_framework/tests/urlpatterns.py +++ b/rest_framework/tests/test_urlpatterns.py diff --git a/rest_framework/tests/validation.py b/rest_framework/tests/test_validation.py index cbdd6515..a6ec0e99 100644 --- a/rest_framework/tests/validation.py +++ b/rest_framework/tests/test_validation.py @@ -63,3 +63,25 @@ class TestPreSaveValidationExclusions(TestCase): # does not have `blank=True`, so this serializer should not validate. serializer = ShouldValidateModelSerializer(data={'renamed': ''}) self.assertEqual(serializer.is_valid(), False) + + +class ValidationSerializer(serializers.Serializer): + foo = serializers.CharField() + + def validate_foo(self, attrs, source): + raise serializers.ValidationError("foo invalid") + + def validate(self, attrs): + raise serializers.ValidationError("serializer invalid") + + +class TestAvoidValidation(TestCase): + """ + If serializer was initialized with invalid data (None or non dict-like), it + should avoid validation layer (validate_<field> and validate methods) + """ + def test_serializer_errors_has_only_invalid_data_error(self): + serializer = ValidationSerializer(data='invalid data') + self.assertFalse(serializer.is_valid()) + self.assertDictEqual(serializer.errors, + {'non_field_errors': ['Invalid data']}) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/test_views.py index 994cf6dc..2767d24c 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/test_views.py @@ -1,12 +1,15 @@ from __future__ import unicode_literals + +import copy + from django.test import TestCase from django.test.client import RequestFactory + from rest_framework import status from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.views import APIView -import copy factory = RequestFactory() diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py deleted file mode 100644 index f8c2579e..00000000 --- a/rest_framework/tests/testcases.py +++ /dev/null @@ -1,66 +0,0 @@ -# http://djangosnippets.org/snippets/1011/ -from __future__ import unicode_literals -from django.conf import settings -from django.core.management import call_command -from django.db.models import loading -from django.test import TestCase - -NO_SETTING = ('!', None) - - -class TestSettingsManager(object): - """ - A class which can modify some Django settings temporarily for a - test and then revert them to their original values later. - - Automatically handles resyncing the DB if INSTALLED_APPS is - modified. - - """ - def __init__(self): - self._original_settings = {} - - def set(self, **kwargs): - for k, v in kwargs.iteritems(): - self._original_settings.setdefault(k, getattr(settings, k, - NO_SETTING)) - setattr(settings, k, v) - if 'INSTALLED_APPS' in kwargs: - self.syncdb() - - def syncdb(self): - loading.cache.loaded = False - call_command('syncdb', verbosity=0) - - def revert(self): - for k, v in self._original_settings.iteritems(): - if v == NO_SETTING: - delattr(settings, k) - else: - setattr(settings, k, v) - if 'INSTALLED_APPS' in self._original_settings: - self.syncdb() - self._original_settings = {} - - -class SettingsTestCase(TestCase): - """ - A subclass of the Django TestCase with a settings_manager - attribute which is an instance of TestSettingsManager. - - Comes with a tearDown() method that calls - self.settings_manager.revert(). - - """ - def __init__(self, *args, **kwargs): - super(SettingsTestCase, self).__init__(*args, **kwargs) - self.settings_manager = TestSettingsManager() - - def tearDown(self): - self.settings_manager.revert() - - -class TestModelsTestCase(SettingsTestCase): - def setUp(self, *args, **kwargs): - installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) - self.settings_manager.set(INSTALLED_APPS=installed_apps) diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py index 08f88e11..554ebd1a 100644 --- a/rest_framework/tests/tests.py +++ b/rest_framework/tests/tests.py @@ -4,11 +4,13 @@ runner to pick up the tests. Yowzers. """ from __future__ import unicode_literals import os +import django modules = [filename.rsplit('.', 1)[0] for filename in os.listdir(os.path.dirname(__file__)) if filename.endswith('.py') and not filename.startswith('_')] __test__ = dict() -for module in modules: - exec("from rest_framework.tests.%s import *" % module) +if django.VERSION < (1, 6): + for module in modules: + exec("from rest_framework.tests.%s import *" % module) |
