diff options
Diffstat (limited to 'rest_framework/tests')
32 files changed, 855 insertions, 85 deletions
| diff --git a/rest_framework/tests/accounts/__init__.py b/rest_framework/tests/accounts/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/accounts/__init__.py diff --git a/rest_framework/tests/accounts/models.py b/rest_framework/tests/accounts/models.py new file mode 100644 index 00000000..525e601b --- /dev/null +++ b/rest_framework/tests/accounts/models.py @@ -0,0 +1,8 @@ +from django.db import models + +from rest_framework.tests.users.models import User + + +class Account(models.Model): +    owner = models.ForeignKey(User, related_name='accounts_owned') +    admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered') diff --git a/rest_framework/tests/accounts/serializers.py b/rest_framework/tests/accounts/serializers.py new file mode 100644 index 00000000..a27b9ca6 --- /dev/null +++ b/rest_framework/tests/accounts/serializers.py @@ -0,0 +1,11 @@ +from rest_framework import serializers + +from rest_framework.tests.accounts.models import Account +from rest_framework.tests.users.serializers import UserSerializer + + +class AccountSerializer(serializers.ModelSerializer): +    admins = UserSerializer(many=True) + +    class Meta: +        model = Account diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 32a726c0..6c8f2342 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -103,7 +103,7 @@ class BlogPostComment(RESTFrameworkModel):  class Album(RESTFrameworkModel):      title = models.CharField(max_length=100, unique=True) - +    ref = models.CharField(max_length=10, unique=True, null=True, blank=True)  class Photo(RESTFrameworkModel):      description = models.TextField() @@ -168,3 +168,10 @@ class NullableOneToOneSource(RESTFrameworkModel):  class BasicModelSerializer(serializers.ModelSerializer):      class Meta:          model = BasicModel + + +# Models to test filters +class FilterableItem(models.Model): +    text = models.CharField(max_length=100) +    decimal = models.DecimalField(max_digits=4, decimal_places=2) +    date = models.DateField() diff --git a/rest_framework/tests/records/__init__.py b/rest_framework/tests/records/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/records/__init__.py diff --git a/rest_framework/tests/records/models.py b/rest_framework/tests/records/models.py new file mode 100644 index 00000000..76954807 --- /dev/null +++ b/rest_framework/tests/records/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class Record(models.Model): +    account = models.ForeignKey('accounts.Account', blank=True, null=True) +    owner = models.ForeignKey('users.User', blank=True, null=True) diff --git a/rest_framework/tests/serializers.py b/rest_framework/tests/serializers.py new file mode 100644 index 00000000..cc943c7d --- /dev/null +++ b/rest_framework/tests/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from rest_framework.tests.models import NullableForeignKeySource + + +class NullableFKSourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableForeignKeySource diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index fb0bc694..6c14debb 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -4,6 +4,7 @@ from django.contrib.auth.models import User  from django.http import HttpResponse  from django.test import TestCase  from django.utils import unittest +from django.utils.http import urlencode  from rest_framework import HTTP_HEADER_ENCODING  from rest_framework import exceptions  from rest_framework import permissions @@ -19,7 +20,7 @@ from rest_framework.authentication import (      OAuth2Authentication  )  from rest_framework.authtoken.models import Token -from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope +from rest_framework.compat import oauth2_provider, oauth2_provider_scope  from rest_framework.compat import oauth, oauth_provider  from rest_framework.test import APIRequestFactory, APIClient  from rest_framework.views import APIView @@ -53,10 +54,14 @@ urlpatterns = patterns('',          permission_classes=[permissions.TokenHasReadWriteScope]))  ) +class OAuth2AuthenticationDebug(OAuth2Authentication): +    allow_query_params_token = True +  if oauth2_provider is not None:      urlpatterns += patterns('',          url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),          url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), +        url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),          url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],              permission_classes=[permissions.TokenHasReadWriteScope])),      ) @@ -488,7 +493,7 @@ class OAuth2Tests(TestCase):          self.ACCESS_TOKEN = "access_token"          self.REFRESH_TOKEN = "refresh_token" -        self.oauth2_client = oauth2_provider_models.Client.objects.create( +        self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(                  client_id=self.CLIENT_ID,                  client_secret=self.CLIENT_SECRET,                  redirect_uri='', @@ -497,12 +502,12 @@ class OAuth2Tests(TestCase):                  user=None,              ) -        self.access_token = oauth2_provider_models.AccessToken.objects.create( +        self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(                  token=self.ACCESS_TOKEN,                  client=self.oauth2_client,                  user=self.user,              ) -        self.refresh_token = oauth2_provider_models.RefreshToken.objects.create( +        self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(                  user=self.user,                  access_token=self.access_token,                  client=self.oauth2_client @@ -546,6 +551,27 @@ class OAuth2Tests(TestCase):          self.assertEqual(response.status_code, 200)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_passing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in form data succeed""" +        response = self.csrf_client.post('/oauth2-test/', +                data={'access_token': self.access_token.token}) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_passing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" +        query = urlencode({'access_token': self.access_token.token}) +        response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_failing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" +        query = urlencode({'access_token': self.access_token.token}) +        response = self.csrf_client.get('/oauth2-test/?%s' % query) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')      def test_post_form_passing_auth(self):          """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""          auth = self._create_authorization_header() diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 5c96bce9..e127feef 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -860,7 +860,9 @@ class SlugFieldTests(TestCase):  class URLFieldTests(TestCase):      """ -    Tests for URLField attribute values +    Tests for URLField attribute values. + +    (Includes test for #1210, checking that validators can be overridden.)      """      class URLFieldModel(RESTFrameworkModel): @@ -902,6 +904,11 @@ class URLFieldTests(TestCase):          self.assertEqual(getattr(serializer.fields['url_field'],                           'max_length'), 20) +    def test_validators_can_be_overridden(self): +        url_field = serializers.URLField(validators=[]) +        validators = url_field.validators +        self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators') +  class FieldMetadata(TestCase):      def setUp(self): diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index 8a03a077..2aa6f81a 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -1,25 +1,21 @@  from __future__ import unicode_literals  import datetime  from decimal import Decimal -from django.conf.urls import patterns, url  from django.db import models  from django.core.urlresolvers import reverse  from django.test import TestCase  from django.utils import unittest +from django.conf.urls import patterns, url  from rest_framework import generics, serializers, status, filters  from rest_framework.compat import django_filters  from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import BasicModel +from .models import FilterableItem +from .utils import temporary_setting  factory = APIRequestFactory() -class FilterableItem(models.Model): -    text = models.CharField(max_length=100) -    decimal = models.DecimalField(max_digits=4, decimal_places=2) -    date = models.DateField() - -  if django_filters:      # Basic filter on a list view.      class FilterFieldsRootView(generics.ListCreateAPIView): @@ -129,7 +125,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the decimal filter works.          search_decimal = Decimal('2.25') -        request = factory.get('/?decimal=%s' % search_decimal) +        request = factory.get('/', {'decimal': '%s' % search_decimal})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['decimal'] == search_decimal] @@ -137,7 +133,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the date filter works.          search_date = datetime.date(2012, 9, 22) -        request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-09-22' +        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-09-22'          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['date'] == search_date] @@ -152,7 +148,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the decimal filter works.          search_decimal = Decimal('2.25') -        request = factory.get('/?decimal=%s' % search_decimal) +        request = factory.get('/', {'decimal': '%s' % search_decimal})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['decimal'] == search_decimal] @@ -185,7 +181,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the decimal filter set with 'lt' in the filter class works.          search_decimal = Decimal('4.25') -        request = factory.get('/?decimal=%s' % search_decimal) +        request = factory.get('/', {'decimal': '%s' % search_decimal})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['decimal'] < search_decimal] @@ -193,7 +189,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the date filter set with 'gt' in the filter class works.          search_date = datetime.date(2012, 10, 2) -        request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-10-02' +        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-10-02'          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['date'] > search_date] @@ -201,7 +197,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the text filter set with 'icontains' in the filter class works.          search_text = 'ff' -        request = factory.get('/?text=%s' % search_text) +        request = factory.get('/', {'text': '%s' % search_text})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if search_text in f['text'].lower()] @@ -210,7 +206,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that multiple filters works.          search_decimal = Decimal('5.25')          search_date = datetime.date(2012, 10, 2) -        request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) +        request = factory.get('/', { +            'decimal': '%s' % (search_decimal,), +            'date': '%s' % (search_date,) +        })          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['date'] > search_date and @@ -235,7 +234,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          view = FilterFieldsRootView.as_view()          search_integer = 10 -        request = factory.get('/?integer=%s' % search_integer) +        request = factory.get('/', {'integer': '%s' % search_integer})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -266,14 +265,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):          # Tests that the decimal filter set that should fail.          search_decimal = Decimal('4.25')          high_item = self.objects.filter(decimal__gt=search_decimal)[0] -        response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) +        response = self.client.get( +            '{url}'.format(url=self._get_url(high_item)), +            {'decimal': '{param}'.format(param=search_decimal)})          self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)          # Tests that the decimal filter set that should succeed.          search_decimal = Decimal('4.25')          low_item = self.objects.filter(decimal__lt=search_decimal)[0]          low_item_data = self._serialize_object(low_item) -        response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) +        response = self.client.get( +            '{url}'.format(url=self._get_url(low_item)), +            {'decimal': '{param}'.format(param=search_decimal)})          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertEqual(response.data, low_item_data) @@ -282,7 +285,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):          search_date = datetime.date(2012, 10, 2)          valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]          valid_item_data = self._serialize_object(valid_item) -        response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) +        response = self.client.get( +            '{url}'.format(url=self._get_url(valid_item)), { +                'decimal': '{decimal}'.format(decimal=search_decimal), +                'date': '{date}'.format(date=search_date) +            })          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertEqual(response.data, valid_item_data) @@ -316,7 +323,7 @@ class SearchFilterTests(TestCase):              search_fields = ('title', 'text')          view = SearchListView.as_view() -        request = factory.get('?search=b') +        request = factory.get('/', {'search': 'b'})          response = view(request)          self.assertEqual(              response.data, @@ -333,7 +340,7 @@ class SearchFilterTests(TestCase):              search_fields = ('=title', 'text')          view = SearchListView.as_view() -        request = factory.get('?search=zzz') +        request = factory.get('/', {'search': 'zzz'})          response = view(request)          self.assertEqual(              response.data, @@ -349,7 +356,7 @@ class SearchFilterTests(TestCase):              search_fields = ('title', '^text')          view = SearchListView.as_view() -        request = factory.get('?search=b') +        request = factory.get('/', {'search': 'b'})          response = view(request)          self.assertEqual(              response.data, @@ -358,6 +365,24 @@ class SearchFilterTests(TestCase):              ]          ) +    def test_search_with_nonstandard_search_param(self): +        with temporary_setting('SEARCH_PARAM', 'query', module=filters): +            class SearchListView(generics.ListAPIView): +                model = SearchFilterModel +                filter_backends = (filters.SearchFilter,) +                search_fields = ('title', 'text') + +            view = SearchListView.as_view() +            request = factory.get('/', {'query': 'b'}) +            response = view(request) +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, 'title': 'z', 'text': 'abc'}, +                    {'id': 2, 'title': 'zz', 'text': 'bcd'} +                ] +            ) +  class OrdringFilterModel(models.Model):      title = models.CharField(max_length=20) @@ -369,7 +394,6 @@ class OrderingFilterRelatedModel(models.Model):                                         related_name="relateds") -  class OrderingFilterTests(TestCase):      def setUp(self):          # Sequence of title/text is: @@ -395,9 +419,10 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = ('title',) +            ordering_fields = ('text',)          view = OrderingListView.as_view() -        request = factory.get('?ordering=text') +        request = factory.get('/', {'ordering': 'text'})          response = view(request)          self.assertEqual(              response.data, @@ -413,9 +438,10 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = ('title',) +            ordering_fields = ('text',)          view = OrderingListView.as_view() -        request = factory.get('?ordering=-text') +        request = factory.get('/', {'ordering': '-text'})          response = view(request)          self.assertEqual(              response.data, @@ -431,9 +457,10 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = ('title',) +            ordering_fields = ('text',)          view = OrderingListView.as_view() -        request = factory.get('?ordering=foobar') +        request = factory.get('/', {'ordering': 'foobar'})          response = view(request)          self.assertEqual(              response.data, @@ -449,6 +476,7 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = ('title',) +            oredering_fields = ('text',)          view = OrderingListView.as_view()          request = factory.get('') @@ -467,6 +495,7 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = 'title' +            ordering_fields = ('text',)          view = OrderingListView.as_view()          request = factory.get('') @@ -495,11 +524,12 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = 'title' +            ordering_fields = '__all__'              queryset = OrdringFilterModel.objects.all().annotate(                  models.Count("relateds"))          view = OrderingListView.as_view() -        request = factory.get('?ordering=relateds__count') +        request = factory.get('/', {'ordering': 'relateds__count'})          response = view(request)          self.assertEqual(              response.data, @@ -510,5 +540,122 @@ class OrderingFilterTests(TestCase):              ]          ) +    def test_ordering_with_nonstandard_ordering_param(self): +        with temporary_setting('ORDERING_PARAM', 'order', filters): +            class OrderingListView(generics.ListAPIView): +                model = OrdringFilterModel +                filter_backends = (filters.OrderingFilter,) +                ordering = ('title',) +                ordering_fields = ('text',) + +            view = OrderingListView.as_view() +            request = factory.get('/', {'order': 'text'}) +            response = view(request) +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                    {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                    {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                ] +            ) + + +class SensitiveOrderingFilterModel(models.Model): +    username = models.CharField(max_length=20) +    password = models.CharField(max_length=100) + + +# Three different styles of serializer. +# All should allow ordering by username, but not by password. +class SensitiveDataSerializer1(serializers.ModelSerializer): +    username = serializers.CharField() +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'username') + +class SensitiveDataSerializer2(serializers.ModelSerializer): +    username = serializers.CharField() +    password = serializers.CharField(write_only=True) + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'username', 'password') + + +class SensitiveDataSerializer3(serializers.ModelSerializer): +    user = serializers.CharField(source='username') + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'user') + + +class SensitiveOrderingFilterTests(TestCase): +    def setUp(self): +        for idx in range(3): +            username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] +            password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] +            SensitiveOrderingFilterModel(username=username, password=password).save() + +    def test_order_by_serializer_fields(self): +        for serializer_cls in [ +            SensitiveDataSerializer1, +            SensitiveDataSerializer2, +            SensitiveDataSerializer3 +        ]: +            class OrderingListView(generics.ListAPIView): +                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') +                filter_backends = (filters.OrderingFilter,) +                serializer_class = serializer_cls + +            view = OrderingListView.as_view() +            request = factory.get('/', {'ordering': '-username'}) +            response = view(request) + +            if serializer_cls == SensitiveDataSerializer3: +                username_field = 'user' +            else: +                username_field = 'username' + +            # Note: Inverse username ordering correctly applied. +            self.assertEqual( +                response.data, +                [ +                    {'id': 3, username_field: 'userC'}, +                    {'id': 2, username_field: 'userB'}, +                    {'id': 1, username_field: 'userA'}, +                ] +            ) + +    def test_cannot_order_by_non_serializer_fields(self): +        for serializer_cls in [ +            SensitiveDataSerializer1, +            SensitiveDataSerializer2, +            SensitiveDataSerializer3 +        ]: +            class OrderingListView(generics.ListAPIView): +                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') +                filter_backends = (filters.OrderingFilter,) +                serializer_class = serializer_cls + +            view = OrderingListView.as_view() +            request = factory.get('/', {'ordering': 'password'}) +            response = view(request) + +            if serializer_cls == SensitiveDataSerializer3: +                username_field = 'user' +            else: +                username_field = 'username' + +            # Note: The passwords are not in order.  Default ordering is used. +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, username_field: 'userA'}, # PassB +                    {'id': 2, username_field: 'userB'}, # PassC +                    {'id': 3, username_field: 'userC'}, # PassA +                ] +            ) diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py index c38bfb9f..fa09c9e6 100644 --- a/rest_framework/tests/test_genericrelations.py +++ b/rest_framework/tests/test_genericrelations.py @@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK  from django.db import models  from django.test import TestCase  from rest_framework import serializers +from rest_framework.compat import python_2_unicode_compatible +@python_2_unicode_compatible  class Tag(models.Model):      """      Tags have a descriptive slug, and are attached to an arbitrary object. @@ -15,10 +17,11 @@ class Tag(models.Model):      object_id = models.PositiveIntegerField()      tagged_item = GenericForeignKey('content_type', 'object_id') -    def __unicode__(self): +    def __str__(self):          return self.tag +@python_2_unicode_compatible  class Bookmark(models.Model):      """      A URL bookmark that may have multiple tags attached. @@ -26,10 +29,11 @@ class Bookmark(models.Model):      url = models.URLField()      tags = GenericRelation(Tag) -    def __unicode__(self): +    def __str__(self):          return 'Bookmark: %s' % self.url +@python_2_unicode_compatible  class Note(models.Model):      """      A textual note that may have multiple tags attached. @@ -37,7 +41,7 @@ class Note(models.Model):      text = models.TextField()      tags = GenericRelation(Tag) -    def __unicode__(self): +    def __str__(self):          return 'Note: %s' % self.text @@ -69,6 +73,35 @@ class TestGenericRelations(TestCase):          }          self.assertEqual(serializer.data, expected) +    def test_generic_nested_relation(self): +        """ +        Test saving a GenericRelation field via a nested serializer. +        """ + +        class TagSerializer(serializers.ModelSerializer): +            class Meta: +                model = Tag +                exclude = ('content_type', 'object_id') + +        class BookmarkSerializer(serializers.ModelSerializer): +            tags = TagSerializer() + +            class Meta: +                model = Bookmark +                exclude = ('id',) + +        data = { +            'url': 'https://docs.djangoproject.com/', +            'tags': [ +                {'tag': 'contenttypes'}, +                {'tag': 'genericrelations'}, +            ] +        } +        serializer = BookmarkSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.object.tags.count(), 2) +      def test_generic_fk(self):          """          Test a relationship that spans a GenericForeignKey field. diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py index 6c570dfd..1cbca04c 100644 --- a/rest_framework/tests/test_htmlrenderer.py +++ b/rest_framework/tests/test_htmlrenderer.py @@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase):          """          self.get_template = django.template.loader.get_template -        def get_template(template_name): +        def get_template(template_name, dirs=None):              if template_name == 'example.html':                  return Template("example: {{ object }}")              raise TemplateDoesNotExist(template_name) @@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase):      def test_not_found_html_view_with_template(self):          response = self.client.get('/not_found')          self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) -        self.assertEqual(response.content, six.b("404: Not found")) +        self.assertTrue(response.content in ( +            six.b("404: Not found"), six.b("404 Not Found")))          self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')      def test_permission_denied_html_view_with_template(self):          response = self.client.get('/permission_denied')          self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) -        self.assertEqual(response.content, six.b("403: Permission denied")) +        self.assertTrue(response.content in ( +            six.b("403: Permission denied"), six.b("403 Forbidden")))          self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index ea7f70f2..5fb1b47e 100644 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -1,8 +1,9 @@  from __future__ import unicode_literals  import json -from django.conf.urls import patterns, url  from django.test import TestCase  from rest_framework import generics, status, serializers +from django.conf.urls import patterns, url +from rest_framework.settings import api_settings  from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import (      Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, @@ -331,3 +332,48 @@ class TestOverriddenURLField(TestCase):              serializer.data,              {'title': 'New blog post', 'url': 'foo bar'}          ) + + +class TestURLFieldNameBySettings(TestCase): +    urls = 'rest_framework.tests.test_hyperlinkedserializers' + +    def setUp(self): +        self.saved_url_field_name = api_settings.URL_FIELD_NAME +        api_settings.URL_FIELD_NAME = 'global_url_field' + +        class Serializer(serializers.HyperlinkedModelSerializer): + +            class Meta: +                model = BlogPost +                fields = ('title', api_settings.URL_FIELD_NAME) + +        self.Serializer = Serializer +        self.obj = BlogPost.objects.create(title="New blog post") + +    def tearDown(self): +        api_settings.URL_FIELD_NAME = self.saved_url_field_name + +    def test_overridden_url_field_name(self): +        request = factory.get('/posts/') +        serializer = self.Serializer(self.obj, context={'request': request}) +        self.assertIn(api_settings.URL_FIELD_NAME, serializer.data) + + +class TestURLFieldNameByOptions(TestCase): +    urls = 'rest_framework.tests.test_hyperlinkedserializers' + +    def setUp(self): +        class Serializer(serializers.HyperlinkedModelSerializer): + +            class Meta: +                model = BlogPost +                fields = ('title', 'serializer_url_field') +                url_field_name = 'serializer_url_field' + +        self.Serializer = Serializer +        self.obj = BlogPost.objects.create(title="New blog post") + +    def test_overridden_url_field_name(self): +        request = factory.get('/posts/') +        serializer = self.Serializer(self.obj, context={'request': request}) +        self.assertIn(self.Serializer.Meta.url_field_name, serializer.data) diff --git a/rest_framework/tests/test_nullable_fields.py b/rest_framework/tests/test_nullable_fields.py new file mode 100644 index 00000000..4812530e --- /dev/null +++ b/rest_framework/tests/test_nullable_fields.py @@ -0,0 +1,30 @@ +from django.core.urlresolvers import reverse + +from django.conf.urls import patterns, url +from rest_framework.test import APITestCase +from rest_framework.tests.models import NullableForeignKeySource +from rest_framework.tests.serializers import NullableFKSourceSerializer +from rest_framework.tests.views import NullableFKSourceDetail + + +urlpatterns = patterns( +    '', +    url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), +) + + +class NullableForeignKeyTests(APITestCase): +    """ +    DRF should be able to handle nullable foreign keys when a test +    Client POST/PUT request is made with its own serialized object. +    """ +    urls = 'rest_framework.tests.test_nullable_fields' + +    def test_updating_object_with_null_fk(self): +        obj = NullableForeignKeySource(name='example', target=None) +        obj.save() +        serialized_data = NullableFKSourceSerializer(obj).data + +        response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) + +        self.assertEqual(response.data, serialized_data) diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index cadb515f..24c1ba39 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -9,14 +9,18 @@ from rest_framework import generics, status, pagination, filters, serializers  from rest_framework.compat import django_filters  from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import BasicModel +from .models import FilterableItem  factory = APIRequestFactory() +# Helper function to split arguments out of an url +def split_arguments_from_url(url): +    if '?' not in url: +        return url -class FilterableItem(models.Model): -    text = models.CharField(max_length=100) -    decimal = models.DecimalField(max_digits=4, decimal_places=2) -    date = models.DateField() +    path, args = url.split('?') +    args = dict(r.split('=') for r in args.split('&')) +    return path, args  class RootView(generics.ListCreateAPIView): @@ -84,7 +88,7 @@ class IntegrationTestPagination(TestCase):          self.assertNotEqual(response.data['next'], None)          self.assertEqual(response.data['previous'], None) -        request = factory.get(response.data['next']) +        request = factory.get(*split_arguments_from_url(response.data['next']))          with self.assertNumQueries(2):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -93,7 +97,7 @@ class IntegrationTestPagination(TestCase):          self.assertNotEqual(response.data['next'], None)          self.assertNotEqual(response.data['previous'], None) -        request = factory.get(response.data['next']) +        request = factory.get(*split_arguments_from_url(response.data['next']))          with self.assertNumQueries(2):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -146,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          EXPECTED_NUM_QUERIES = 2 -        request = factory.get('/?decimal=15.20') +        request = factory.get('/', {'decimal': '15.20'})          with self.assertNumQueries(EXPECTED_NUM_QUERIES):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -155,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.assertNotEqual(response.data['next'], None)          self.assertEqual(response.data['previous'], None) -        request = factory.get(response.data['next']) +        request = factory.get(*split_arguments_from_url(response.data['next']))          with self.assertNumQueries(EXPECTED_NUM_QUERIES):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -164,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.assertEqual(response.data['next'], None)          self.assertNotEqual(response.data['previous'], None) -        request = factory.get(response.data['previous']) +        request = factory.get(*split_arguments_from_url(response.data['previous']))          with self.assertNumQueries(EXPECTED_NUM_QUERIES):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -191,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          view = BasicFilterFieldsRootView.as_view() -        request = factory.get('/?decimal=15.20') +        request = factory.get('/', {'decimal': '15.20'})          with self.assertNumQueries(2):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -200,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.assertNotEqual(response.data['next'], None)          self.assertEqual(response.data['previous'], None) -        request = factory.get(response.data['next']) +        request = factory.get(*split_arguments_from_url(response.data['next']))          with self.assertNumQueries(2):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -209,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.assertEqual(response.data['next'], None)          self.assertNotEqual(response.data['previous'], None) -        request = factory.get(response.data['previous']) +        request = factory.get(*split_arguments_from_url(response.data['previous']))          with self.assertNumQueries(2):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -317,7 +321,7 @@ class TestCustomPaginateByParam(TestCase):          """          If paginate_by_param is set, the new kwarg should limit per view requests.          """ -        request = factory.get('/?page_size=5') +        request = factory.get('/', {'page_size': 5})          response = self.view(request).render()          self.assertEqual(response.data['count'], 13)          self.assertEqual(response.data['results'], self.data[:5]) @@ -345,7 +349,7 @@ class TestMaxPaginateByParam(TestCase):          """          If max_paginate_by is set, it should limit page size for the view.          """ -        request = factory.get('/?page_size=10') +        request = factory.get('/', data={'page_size': 10})          response = self.view(request).render()          self.assertEqual(response.data['count'], 13)          self.assertEqual(response.data['results'], self.data[:5]) diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py index d19219c9..37ac826b 100644 --- a/rest_framework/tests/test_relations.py +++ b/rest_framework/tests/test_relations.py @@ -2,8 +2,10 @@  General tests for relational fields.  """  from __future__ import unicode_literals +from django import get_version  from django.db import models  from django.test import TestCase +from django.utils import unittest  from rest_framework import serializers  from rest_framework.tests.models import BlogPost @@ -98,3 +100,45 @@ class RelatedFieldSourceTests(TestCase):          obj = ClassWithQuerysetMethod()          value = field.field_to_native(obj, 'field_name')          self.assertEqual(value, ['BlogPost object']) + +    # Regression for #1129 +    def test_exception_for_incorect_fk(self): +        """ +        Check that the exception message are correct if the source field +        doesn't exist. +        """ +        from rest_framework.tests.models import ManyToManySource +        class Meta: +            model = ManyToManySource +        attrs = { +            'name': serializers.SlugRelatedField( +                slug_field='name', source='banzai'), +            'Meta': Meta, +        } + +        TestSerializer = type(str('TestSerializer'), +            (serializers.ModelSerializer,), attrs) +        with self.assertRaises(AttributeError): +            TestSerializer(data={'name': 'foo'}) + +@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') +class RelatedFieldChoicesTests(TestCase): +    """ +    Tests for #1408 "Web browseable API doesn't have blank option on drop down list box" +    https://github.com/tomchristie/django-rest-framework/issues/1408 +    """ +    def test_blank_option_is_added_to_choice_if_required_equals_false(self): +        """ + +        """ +        post = BlogPost(title="Checking blank option is added") +        post.save() + +        queryset = BlogPost.objects.all() +        field = serializers.RelatedField(required=False, queryset=queryset) + +        choice_count = BlogPost.objects.count() +        widget_count = len(field.widget.choices) + +        self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') + diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index d393b0c3..4d9da489 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -3,9 +3,7 @@ from django.db import models  from django.test import TestCase  from rest_framework import serializers - -class OneToOneTarget(models.Model): -    name = models.CharField(max_length=100) +from .models import OneToOneTarget  class OneToOneSource(models.Model): diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index 9cb68233..460c02a9 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals  from decimal import Decimal  from django.conf.urls import patterns, url, include  from django.core.cache import cache +from django.db import models  from django.test import TestCase  from django.utils import unittest  from django.utils.translation import ugettext_lazy as _ @@ -35,6 +36,10 @@ expected_results = [  ] +class DummyTestModel(models.Model): +    name = models.CharField(max_length=42, default='') + +  class BasicRendererTests(TestCase):      def test_expected_results(self):          for value, renderer_cls, expected in expected_results: @@ -252,6 +257,18 @@ class RendererEndToEndTests(TestCase):          self.assertEqual(resp.get('Content-Type', None), None)          self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) +    def test_contains_headers_of_api_response(self): +        """ +        Issue #1437 + +        Test we display the headers of the API response and not those from the +        HTML response +        """ +        resp = self.client.get('/html1') +        self.assertContains(resp, '>GET, HEAD, OPTIONS<') +        self.assertContains(resp, '>application/json<') +        self.assertNotContains(resp, '>text/html; charset=utf-8<') +  _flat_repr = '{"foo": ["bar", "baz"]}'  _indented_repr = '{\n  "foo": [\n    "bar",\n    "baz"\n  ]\n}' @@ -277,6 +294,20 @@ class JSONRendererTests(TestCase):          ret = JSONRenderer().render(_('test'))          self.assertEqual(ret, b'"test"') +    def test_render_queryset_values(self): +        o = DummyTestModel.objects.create(name='dummy') +        qs = DummyTestModel.objects.values('id', 'name') +        ret = JSONRenderer().render(qs) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, [{'id': o.id, 'name': o.name}]) + +    def test_render_queryset_values_list(self): +        o = DummyTestModel.objects.create(name='dummy') +        qs = DummyTestModel.objects.values_list('id', 'name') +        ret = JSONRenderer().render(qs) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, [[o.id, o.name]]) +      def test_render_dict_abc_obj(self):          class Dict(MutableMapping):              def __init__(self): @@ -583,6 +614,10 @@ class CacheRenderTest(TestCase):          method = getattr(self.client, http_method)          resp = method(url)          del resp.client, resp.request +        try: +            del resp.wsgi_request +        except AttributeError: +            pass          return resp      def test_obj_pickling(self): diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index f07c31a3..e0da5fd4 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -68,6 +68,9 @@ class TestMethodOverloading(TestCase):          request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))          self.assertEqual(request.method, 'DELETE') +        request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) +        self.assertEqual(request.method, 'DELETE') +  class TestContentParsing(TestCase):      def test_standard_behaviour_determines_no_content_GET(self): diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 6d9b85ee..a09bf6f5 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals  from django.db import models  from django.db.models.fields import BLANK_CHOICE_DASH  from django.test import TestCase +from django.utils import unittest  from django.utils.datastructures import MultiValueDict  from django.utils.translation import ugettext_lazy as _  from rest_framework import serializers, fields, relations @@ -12,6 +13,31 @@ from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, Acti  from rest_framework.tests.models import BasicModelSerializer  import datetime  import pickle +try: +    import PIL +except: +    PIL = None + + +if PIL is not None: +    class AMOAFModel(RESTFrameworkModel): +        char_field = models.CharField(max_length=1024, blank=True) +        comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) +        decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) +        email_field = models.EmailField(max_length=1024, blank=True) +        file_field = models.FileField(upload_to='test', max_length=1024, blank=True) +        image_field = models.ImageField(upload_to='test', max_length=1024, blank=True) +        slug_field = models.SlugField(max_length=1024, blank=True) +        url_field = models.URLField(max_length=1024, blank=True) + +    class DVOAFModel(RESTFrameworkModel): +        positive_integer_field = models.PositiveIntegerField(blank=True) +        positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) +        email_field = models.EmailField(blank=True) +        file_field = models.FileField(upload_to='test', blank=True) +        image_field = models.ImageField(upload_to='test', blank=True) +        slug_field = models.SlugField(blank=True) +        url_field = models.URLField(blank=True)  class SubComment(object): @@ -71,6 +97,15 @@ class ActionItemSerializer(serializers.ModelSerializer):      class Meta:          model = ActionItem +class ActionItemSerializerOptionalFields(serializers.ModelSerializer): +    """ +    Intended to test that fields with `required=False` are excluded from validation. +    """ +    title = serializers.CharField(required=False) + +    class Meta: +        model = ActionItem +        fields = ('title',)  class ActionItemSerializerCustomRestore(serializers.ModelSerializer): @@ -132,7 +167,7 @@ class AlbumsSerializer(serializers.ModelSerializer):      class Meta:          model = Album -        fields = ['title']  # lists are also valid options +        fields = ['title', 'ref']  # lists are also valid options  class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): @@ -288,7 +323,13 @@ class BasicTests(TestCase):          serializer.save()          self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.') - +    def test_fields_marked_as_not_required_are_excluded_from_validation(self): +        """ +        Check that fields with `required=False` are included in list of exclusions. +        """ +        serializer = ActionItemSerializerOptionalFields(self.actionitem) +        exclusions = serializer.get_validation_exclusions() +        self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded')  class DictStyleSerializer(serializers.Serializer): @@ -467,6 +508,32 @@ class ValidationTests(TestCase):          )          self.assertEqual(serializer.is_valid(), True) +    def test_writable_star_source_on_nested_serializer_with_parent_object(self): +        class TitleSerializer(serializers.Serializer): +            title = serializers.WritableField(source='title') + +        class AlbumSerializer(serializers.ModelSerializer): +            nested = TitleSerializer(source='*') + +            class Meta: +                model = Album +                fields = ('nested',) + +        class PhotoSerializer(serializers.ModelSerializer): +            album = AlbumSerializer(source='album') + +            class Meta: +                model = Photo +                fields = ('album', ) + +        photo = Photo(album=Album()) + +        data = {'album': {'nested': {'title': 'test'}}} + +        serializer = PhotoSerializer(photo, data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) +      def test_writable_star_source_with_inner_source_fields(self):          """          Tests that a serializer with source="*" correctly expands the @@ -576,12 +643,15 @@ class ModelValidationTests(TestCase):          """          Just check if serializers.ModelSerializer handles unique checks via .full_clean()          """ -        serializer = AlbumsSerializer(data={'title': 'a'}) +        serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'})          serializer.is_valid()          serializer.save()          second_serializer = AlbumsSerializer(data={'title': 'a'})          self.assertFalse(second_serializer.is_valid()) -        self.assertEqual(second_serializer.errors,  {'title': ['Album with this Title already exists.']}) +        self.assertEqual(second_serializer.errors,  {'title': ['Album with this Title already exists.'],}) +        third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}]) +        self.assertFalse(third_serializer.is_valid()) +        self.assertEqual(third_serializer.errors,  [{'ref': ['Album with this Ref already exists.']}, {}])      def test_foreign_key_is_null_with_partial(self):          """ @@ -865,6 +935,58 @@ class DefaultValueTests(TestCase):          self.assertEqual(instance.text, 'overridden') +class WritableFieldDefaultValueTests(TestCase): + +    def setUp(self): +        self.expected = {'default': 'value'} +        self.create_field = fields.WritableField + +    def test_get_default_value_with_noncallable(self): +        field = self.create_field(default=self.expected) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_with_callable(self): +        field = self.create_field(default=lambda : self.expected) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_when_not_required(self): +        field = self.create_field(default=self.expected, required=False) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_returns_None(self): +        field = self.create_field() +        got = field.get_default_value() +        self.assertIsNone(got) + +    def test_get_default_value_returns_non_True_values(self): +        values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause +        for expected in values: +            field = self.create_field(default=expected) +            got = field.get_default_value() +            self.assertEqual(got, expected) + + +class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): + +    def setUp(self): +        self.expected = {'foo': 'bar'} +        self.create_field = relations.RelatedField + +    def test_get_default_value_returns_empty_list(self): +        field = self.create_field(many=True) +        got = field.get_default_value() +        self.assertListEqual(got, []) + +    def test_get_default_value_returns_expected(self): +        expected = [1, 2, 3] +        field = self.create_field(many=True, default=expected) +        got = field.get_default_value() +        self.assertListEqual(got, expected) + +  class CallableDefaultValueTests(TestCase):      def setUp(self):          class CallableDefaultValueSerializer(serializers.ModelSerializer): @@ -1492,19 +1614,10 @@ class ManyFieldHelpTextTest(TestCase):          self.assertEqual('Some help text.', rel_field.help_text) +@unittest.skipUnless(PIL is not None, 'PIL is not installed')  class AttributeMappingOnAutogeneratedFieldsTests(TestCase):      def setUp(self): -        class AMOAFModel(RESTFrameworkModel): -            char_field = models.CharField(max_length=1024, blank=True) -            comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) -            decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) -            email_field = models.EmailField(max_length=1024, blank=True) -            file_field = models.FileField(max_length=1024, blank=True) -            image_field = models.ImageField(max_length=1024, blank=True) -            slug_field = models.SlugField(max_length=1024, blank=True) -            url_field = models.URLField(max_length=1024, blank=True) -            nullable_char_field = models.CharField(max_length=1024, blank=True, null=True)          class AMOAFSerializer(serializers.ModelSerializer):              class Meta: @@ -1581,17 +1694,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):          self.field_test('nullable_char_field') +@unittest.skipUnless(PIL is not None, 'PIL is not installed')  class DefaultValuesOnAutogeneratedFieldsTests(TestCase):      def setUp(self): -        class DVOAFModel(RESTFrameworkModel): -            positive_integer_field = models.PositiveIntegerField(blank=True) -            positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) -            email_field = models.EmailField(blank=True) -            file_field = models.FileField(blank=True) -            image_field = models.ImageField(blank=True) -            slug_field = models.SlugField(blank=True) -            url_field = models.URLField(blank=True)          class DVOAFSerializer(serializers.ModelSerializer):              class Meta: @@ -1830,14 +1936,14 @@ class SerializerDefaultTrueBoolean(TestCase):          self.assertEqual(serializer.data['cat'], False)          self.assertEqual(serializer.data['dog'], False) -         +  class BoolenFieldTypeTest(TestCase):      '''      Ensure the various Boolean based model fields are rendered as the proper      field type -     +      ''' -     +      def setUp(self):          '''          Setup an ActionItemSerializer for BooleanTesting @@ -1853,11 +1959,11 @@ class BoolenFieldTypeTest(TestCase):          '''          bfield = self.serializer.get_fields()['done']          self.assertEqual(type(bfield), fields.BooleanField) -     +      def test_nullbooleanfield_type(self):          ''' -        Test that BooleanField is infered from models.NullBooleanField  -         +        Test that BooleanField is infered from models.NullBooleanField +          https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8          '''          bfield = self.serializer.get_fields()['started'] diff --git a/rest_framework/tests/test_serializer_import.py b/rest_framework/tests/test_serializer_import.py new file mode 100644 index 00000000..9f30a7ff --- /dev/null +++ b/rest_framework/tests/test_serializer_import.py @@ -0,0 +1,19 @@ +from django.test import TestCase + +from rest_framework import serializers +from rest_framework.tests.accounts.serializers import AccountSerializer + + +class ImportingModelSerializerTests(TestCase): +    """ +    In some situations like, GH #1225, it is possible, especially in +    testing, to import a serializer who's related models have not yet +    been resolved by Django. `AccountSerializer` is an example of such +    a serializer (imported at the top of this file). +    """ +    def test_import_model_serializer(self): +        """ +        The serializer at the top of this file should have been +        imported successfully, and we should be able to instantiate it. +        """ +        self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer) diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py index 7114a060..6d69ffbd 100644 --- a/rest_framework/tests/test_serializer_nested.py +++ b/rest_framework/tests/test_serializer_nested.py @@ -345,4 +345,3 @@ class NestedModelSerializerUpdateTests(TestCase):          result = deserialize.object          result.save()          self.assertEqual(result.id, john.id) - diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py new file mode 100644 index 00000000..082a400c --- /dev/null +++ b/rest_framework/tests/test_serializers.py @@ -0,0 +1,28 @@ +from django.db import models +from django.test import TestCase + +from rest_framework.serializers import _resolve_model +from rest_framework.tests.models import BasicModel + + +class ResolveModelTests(TestCase): +    """ +    `_resolve_model` should return a Django model class given the +    provided argument is a Django model class itself, or a properly +    formatted string representation of one. +    """ +    def test_resolve_django_model(self): +        resolved_model = _resolve_model(BasicModel) +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_string_representation(self): +        resolved_model = _resolve_model('tests.BasicModel') +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_non_django_model(self): +        with self.assertRaises(ValueError): +            _resolve_model(TestCase) + +    def test_resolve_improper_string_representation(self): +        with self.assertRaises(ValueError): +            _resolve_model('BasicModel') diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py new file mode 100644 index 00000000..d4da0c23 --- /dev/null +++ b/rest_framework/tests/test_templatetags.py @@ -0,0 +1,51 @@ +# encoding: utf-8 +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.test import APIRequestFactory +from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links + +factory = APIRequestFactory() + + +class TemplateTagTests(TestCase): + +    def test_add_query_param_with_non_latin_charactor(self): +        # Ensure we don't double-escape non-latin characters +        # that are present in the querystring. +        # See #1314. +        request = factory.get("/", {'q': '查询'}) +        json_url = add_query_param(request, "format", "json") +        self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url) +        self.assertIn("format=json", json_url) + + +class Issue1386Tests(TestCase): +    """ +    Covers #1386 +    """ + +    def test_issue_1386(self): +        """ +        Test function urlize_quoted_links with different args +        """ +        correct_urls = [ +            "asdf.com", +            "asdf.net", +            "www.as_df.org", +            "as.d8f.ghj8.gov", +        ] +        for i in correct_urls: +            res = urlize_quoted_links(i) +            self.assertNotEqual(res, i) +            self.assertIn(i, res) + +        incorrect_urls = [ +            "mailto://asdf@fdf.com", +            "asdf.netnet", +        ] +        for i in incorrect_urls: +            res = urlize_quoted_links(i) +            self.assertEqual(i, res) + +        # example from issue #1386, this shouldn't raise an exception +        _ = urlize_quoted_links("asdf:[/p]zxcv.com") diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index c08dd493..83ae8148 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -2,6 +2,8 @@  from __future__ import unicode_literals  from django.conf.urls import patterns, url +from io import BytesIO +  from django.contrib.auth.models import User  from django.test import TestCase  from rest_framework.decorators import api_view @@ -143,3 +145,20 @@ class TestAPIRequestFactory(TestCase):          force_authenticate(request, user=user)          response = view(request)          self.assertEqual(response.data['user'], 'example') + +    def test_upload_file(self): +        # This is a 1x1 black png +        simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82') +        simple_png.name = 'test.png' +        factory = APIRequestFactory() +        factory.post('/', data={'image': simple_png}) + +    def test_request_factory_url_arguments(self): +        """ +        This is a non regression test against #1461 +        """ +        factory = APIRequestFactory() +        request = factory.get('/view/?demo=test') +        self.assertEqual(dict(request.GET), {'demo': ['test']}) +        request = factory.get('/view/', {'demo': 'test'}) +        self.assertEqual(dict(request.GET), {'demo': ['test']}) diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py index 124c874d..e13e4078 100644 --- a/rest_framework/tests/test_validation.py +++ b/rest_framework/tests/test_validation.py @@ -1,4 +1,5 @@  from __future__ import unicode_literals +from django.core.validators import MaxValueValidator  from django.db import models  from django.test import TestCase  from rest_framework import generics, serializers, status @@ -102,3 +103,46 @@ class TestAvoidValidation(TestCase):          self.assertFalse(serializer.is_valid())          self.assertDictEqual(serializer.errors,                               {'non_field_errors': ['Invalid data']}) + + +# regression tests for issue: 1493 + +class ValidationMaxValueValidatorModel(models.Model): +    number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)]) + + +class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ValidationMaxValueValidatorModel + + +class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): +    model = ValidationMaxValueValidatorModel +    serializer_class = ValidationMaxValueValidatorModelSerializer + + +class TestMaxValueValidatorValidation(TestCase): + +    def test_max_value_validation_serializer_success(self): +        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99}) +        self.assertTrue(serializer.is_valid()) + +    def test_max_value_validation_serializer_fails(self): +        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101}) +        self.assertFalse(serializer.is_valid()) +        self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors) + +    def test_max_value_validation_success(self): +        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) +        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json') +        view = UpdateMaxValueValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_max_value_validation_fail(self): +        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) +        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') +        view = UpdateMaxValueValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/tests/test_write_only_fields.py b/rest_framework/tests/test_write_only_fields.py new file mode 100644 index 00000000..aabb18d6 --- /dev/null +++ b/rest_framework/tests/test_write_only_fields.py @@ -0,0 +1,42 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class ExampleModel(models.Model): +    email = models.EmailField(max_length=100) +    password = models.CharField(max_length=100) + + +class WriteOnlyFieldTests(TestCase): +    def test_write_only_fields(self): +        class ExampleSerializer(serializers.Serializer): +            email = serializers.EmailField() +            password = serializers.CharField(write_only=True) + +        data = { +            'email': 'foo@example.com', +            'password': '123' +        } +        serializer = ExampleSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.object, data) +        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) + +    def test_write_only_fields_meta(self): +        class ExampleSerializer(serializers.ModelSerializer): +            class Meta: +                model = ExampleModel +                fields = ('email', 'password') +                write_only_fields = ('password',) + +        data = { +            'email': 'foo@example.com', +            'password': '123' +        } +        serializer = ExampleSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertTrue(isinstance(serializer.object, ExampleModel)) +        self.assertEquals(serializer.object.email, data['email']) +        self.assertEquals(serializer.object.password, data['password']) +        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) diff --git a/rest_framework/tests/users/__init__.py b/rest_framework/tests/users/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/users/__init__.py diff --git a/rest_framework/tests/users/models.py b/rest_framework/tests/users/models.py new file mode 100644 index 00000000..128bac90 --- /dev/null +++ b/rest_framework/tests/users/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class User(models.Model): +    account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users') +    active_record = models.ForeignKey('records.Record', blank=True, null=True) diff --git a/rest_framework/tests/users/serializers.py b/rest_framework/tests/users/serializers.py new file mode 100644 index 00000000..da496554 --- /dev/null +++ b/rest_framework/tests/users/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from rest_framework.tests.users.models import User + + +class UserSerializer(serializers.ModelSerializer): +    class Meta: +        model = User diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..a8f2eb0b --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,25 @@ +from contextlib import contextmanager +from rest_framework.compat import six +from rest_framework.settings import api_settings + + +@contextmanager +def temporary_setting(setting, value, module=None): +    """ +    Temporarily change value of setting for test. + +    Optionally reload given module, useful when module uses value of setting on +    import. +    """ +    original_value = getattr(api_settings, setting) +    setattr(api_settings, setting, value) + +    if module is not None: +        six.moves.reload_module(module) + +    yield + +    setattr(api_settings, setting, original_value) + +    if module is not None: +        six.moves.reload_module(module) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py new file mode 100644 index 00000000..3917b74a --- /dev/null +++ b/rest_framework/tests/views.py @@ -0,0 +1,8 @@ +from rest_framework import generics +from rest_framework.tests.models import NullableForeignKeySource +from rest_framework.tests.serializers import NullableFKSourceSerializer + + +class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): +    model = NullableForeignKeySource +    model_serializer_class = NullableFKSourceSerializer | 
