aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
authorRyan Kaskel2013-05-18 14:17:50 +0100
committerRyan Kaskel2013-05-18 14:17:50 +0100
commit22874e441dd71101296a656e753bfc17907b5cca (patch)
tree6ebf7971e5bf8d40c6d60fa857cbe0c04fc91372 /rest_framework/tests
parentb5640bb77843c50f42a649982b9b9592113c6f59 (diff)
parenta0e3c44c99a61a6dc878308bdf0890fbb10c41e4 (diff)
downloaddjango-rest-framework-22874e441dd71101296a656e753bfc17907b5cca.tar.bz2
Merge latest changes from master.
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/authentication.py44
-rw-r--r--rest_framework/tests/description.py63
-rw-r--r--rest_framework/tests/fields.py208
-rw-r--r--rest_framework/tests/filters.py474
-rw-r--r--rest_framework/tests/filterset.py169
-rw-r--r--rest_framework/tests/generics.py105
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py40
-rw-r--r--rest_framework/tests/models.py7
-rw-r--r--rest_framework/tests/pagination.py24
-rw-r--r--rest_framework/tests/parsers.py33
-rw-r--r--rest_framework/tests/relations.py55
-rw-r--r--rest_framework/tests/relations_hyperlink.py87
-rw-r--r--rest_framework/tests/relations_nested.py24
-rw-r--r--rest_framework/tests/relations_pk.py74
-rw-r--r--rest_framework/tests/routers.py55
-rw-r--r--rest_framework/tests/serializer.py224
-rw-r--r--rest_framework/tests/serializer_bulk_update.py34
-rw-r--r--rest_framework/tests/serializer_nested.py4
18 files changed, 1368 insertions, 356 deletions
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index b663ca48..8e6d3e51 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -466,17 +466,13 @@ class OAuth2Tests(TestCase):
def _create_authorization_header(self, token=None):
return "Bearer {0}".format(token or self.access_token.token)
- def _client_credentials_params(self):
- return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
-
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_with_wrong_authorization_header_token_type_failing(self):
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
auth = "Wrong token-type-obsviously"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -485,8 +481,7 @@ class OAuth2Tests(TestCase):
auth = "Bearer wrong token format"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -495,33 +490,21 @@ class OAuth2Tests(TestCase):
auth = "Bearer wrong-token"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_client_data_failing_auth(self):
- """Ensure GETing form over OAuth with incorrect client credentials fails"""
- auth = self._create_authorization_header()
- params = self._client_credentials_params()
- params['client_id'] += 'a'
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth(self):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -529,16 +512,14 @@ class OAuth2Tests(TestCase):
"""Ensure POSTing when there is no OAuth access token in db fails"""
self.access_token.delete()
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_with_refresh_token_failing_auth(self):
"""Ensure POSTing with refresh token instead of access token fails"""
auth = self._create_authorization_header(token=self.refresh_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -547,8 +528,7 @@ class OAuth2Tests(TestCase):
self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
self.access_token.save()
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
self.assertIn('Invalid token', response.content)
@@ -559,10 +539,9 @@ class OAuth2Tests(TestCase):
read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
read_only_access_token.save()
auth = self._create_authorization_header(token=read_only_access_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
- response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -572,6 +551,5 @@ class OAuth2Tests(TestCase):
read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
read_write_access_token.save()
auth = self._create_authorization_header(token=read_write_access_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
index 5b3315bc..52c1a34c 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/description.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.views import APIView
from rest_framework.compat import apply_markdown
+from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
@@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
class TestViewNamesAndDescriptions(TestCase):
- def test_resource_name_uses_classname_by_default(self):
- """Ensure Resource names are based on the classname by default."""
+ def test_view_name_uses_class_name(self):
+ """
+ Ensure view names are based on the class name.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_name(), 'Mock')
+ self.assertEqual(get_view_name(MockView), 'Mock')
- def test_resource_name_can_be_set_explicitly(self):
- """Ensure Resource names can be set using the 'get_name' method."""
- example = 'Some Other Name'
- class MockView(APIView):
- def get_name(self):
- return example
- self.assertEqual(MockView().get_name(), example)
-
- def test_resource_description_uses_docstring_by_default(self):
- """Ensure Resource names are based on the docstring by default."""
+ def test_view_description_uses_docstring(self):
+ """Ensure view descriptions are based on the docstring."""
class MockView(APIView):
"""an example docstring
====================
@@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEqual(MockView().get_description(), DESCRIPTION)
-
- def test_resource_description_can_be_set_explicitly(self):
- """Ensure Resource descriptions can be set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- """docstring"""
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), DESCRIPTION)
- def test_resource_description_supports_unicode(self):
+ def test_view_description_supports_unicode(self):
+ """
+ Unicode in docstrings should be respected.
+ """
class MockView(APIView):
"""Проверка"""
pass
- self.assertEqual(MockView().get_description(), "Проверка")
-
-
- def test_resource_description_does_not_require_docstring(self):
- """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), "Проверка")
- def test_resource_description_can_be_empty(self):
- """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
+ def test_view_description_can_be_empty(self):
+ """
+ Ensure that if a view has no docstring,
+ then it's description is the empty string.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_description(), '')
+ self.assertEqual(get_view_description(MockView), '')
def test_markdown(self):
- """Ensure markdown to HTML works as expected"""
+ """
+ Ensure markdown to HTML works as expected.
+ """
if apply_markdown:
gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 19c663d8..6b1cdfc7 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -2,13 +2,14 @@
General serializer field tests.
"""
from __future__ import unicode_literals
+from django.utils.datastructures import SortedDict
import datetime
-
+from decimal import Decimal
from django.db import models
from django.test import TestCase
from django.core import validators
-
from rest_framework import serializers
+from rest_framework.serializers import Serializer
class TimestampedModel(models.Model):
@@ -61,6 +62,20 @@ class BasicFieldTests(TestCase):
serializer = CharPrimaryKeyModelSerializer()
self.assertEqual(serializer.fields['id'].read_only, False)
+ def test_dict_field_ordering(self):
+ """
+ Field should preserve dictionary ordering, if it exists.
+ See: https://github.com/tomchristie/django-rest-framework/issues/832
+ """
+ ret = SortedDict()
+ ret['c'] = 1
+ ret['b'] = 1
+ ret['a'] = 1
+ ret['z'] = 1
+ field = serializers.Field()
+ keys = list(field.to_native(ret).keys())
+ self.assertEqual(keys, ['c', 'b', 'a', 'z'])
+
class DateFieldTest(TestCase):
"""
@@ -481,3 +496,192 @@ class TimeFieldTest(TestCase):
self.assertEqual('04 - 00 [000000]', result_1)
self.assertEqual('04 - 59 [000000]', result_2)
self.assertEqual('04 - 59 [000200]', result_3)
+
+
+class DecimalFieldTest(TestCase):
+ """
+ Tests for the DecimalField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts string values
+ """
+ f = serializers.DecimalField()
+ result_1 = f.from_native('9000')
+ result_2 = f.from_native('1.00000001')
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_from_native_invalid_string(self):
+ """
+ Make sure from_native() raises ValidationError on passing invalid string
+ """
+ f = serializers.DecimalField()
+
+ try:
+ f.from_native('123.45.6')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Enter a number."])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_integer(self):
+ """
+ Make sure from_native() accepts integer values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(9000)
+
+ self.assertEqual(Decimal('9000'), result)
+
+ def test_from_native_float(self):
+ """
+ Make sure from_native() accepts float values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(1.00000001)
+
+ self.assertEqual(Decimal('1.00000001'), result)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns Decimal as string.
+ """
+ f = serializers.DecimalField()
+
+ result_1 = f.to_native(Decimal('9000'))
+ result_2 = f.to_native(Decimal('1.00000001'))
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+ def test_valid_serialization(self):
+ """
+ Make sure the serializer works correctly
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=9010,
+ min_value=9000,
+ max_digits=6,
+ decimal_places=2)
+
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
+
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
+
+ def test_raise_max_value(self):
+ """
+ Make sure max_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '123'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
+
+ def test_raise_min_value(self):
+ """
+ Make sure min_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(min_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '99'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
+
+ def test_raise_max_digits(self):
+ """
+ Make sure max_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=5)
+
+ s = DecimalSerializer(data={'decimal_field': '123.456'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
+
+ def test_raise_max_decimal_places(self):
+ """
+ Make sure max_decimal_places violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '123.4567'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
+
+ def test_raise_max_whole_digits(self):
+ """
+ Make sure max_whole_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '12345.6'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
+
+
+class ChoiceFieldTests(TestCase):
+ """
+ Tests for the ChoiceField options generator
+ """
+
+ SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+ ]
+
+ def test_choices_required(self):
+ """
+ Make sure proper choices are rendered if field is required
+ """
+ f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES)
+ self.assertEqual(f.choices, self.SAMPLE_CHOICES)
+
+ def test_choices_not_required(self):
+ """
+ Make sure proper choices (plus blank) are rendered if the field isn't required
+ """
+ f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
diff --git a/rest_framework/tests/filters.py b/rest_framework/tests/filters.py
new file mode 100644
index 00000000..8ae6d530
--- /dev/null
+++ b/rest_framework/tests/filters.py
@@ -0,0 +1,474 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.core.urlresolvers import reverse
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import unittest
+from rest_framework import generics, serializers, status, filters
+from rest_framework.compat import django_filters, patterns, url
+from rest_framework.tests.models import BasicModel
+
+factory = RequestFactory()
+
+
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # Regression test for #814
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+ class FilterFieldsQuerysetView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
+
+ urlpatterns = patterns('',
+ url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
+ )
+
+
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ self._serialize_object(obj)
+ for obj in self.objects.all()
+ ]
+
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_queryset(self):
+ """
+ Regression test for #814.
+ """
+ view = FilterFieldsQuerysetView.as_view()
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'rest_framework.tests.filters'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
deleted file mode 100644
index 238da56e..00000000
--- a/rest_framework/tests/filterset.py
+++ /dev/null
@@ -1,169 +0,0 @@
-from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.test import TestCase
-from django.test.client import RequestFactory
-from django.utils import unittest
-from rest_framework import generics, status, filters
-from rest_framework.compat import django_filters
-from rest_framework.tests.models import FilterableItem, BasicModel
-
-factory = RequestFactory()
-
-
-if django_filters:
- # Basic filter on a list view.
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_fields = ['decimal', 'date']
- filter_backend = filters.DjangoFilterBackend
-
- # These class are used to test a filter class.
- class SeveralFieldsFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
- decimal = django_filters.NumberFilter(lookup_type='lt')
- date = django_filters.DateFilter(lookup_type='gt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterClassRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
- filter_backend = filters.DjangoFilterBackend
-
- # These classes are used to test a misconfigured filter class.
- class MisconfiguredFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
-
- class Meta:
- model = BasicModel
- fields = ['text']
-
- class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = MisconfiguredFilter
- filter_backend = filters.DjangoFilterBackend
-
-
-class IntegrationTestFiltering(TestCase):
- """
- Integration tests for filtered list views.
- """
-
- def setUp(self):
- """
- Create 10 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(10):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
- ]
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_filtered_fields_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- view = FilterFieldsRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter works.
- search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter works.
- search_date = datetime.date(2012, 9, 22)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_get_filtered_class_root_view(self):
- """
- GET requests to filtered ListCreateAPIView that have a filter_class set
- should return filtered results.
- """
- view = FilterClassRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter set with 'lt' in the filter class works.
- search_decimal = Decimal('4.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter set with 'gt' in the filter class works.
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the text filter set with 'icontains' in the filter class works.
- search_text = 'ff'
- request = factory.get('/?text=%s' % search_text)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if search_text in f['text'].lower()]
- self.assertEqual(response.data, expected_data)
-
- # Tests that multiple filters works.
- search_decimal = Decimal('5.25')
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_incorrectly_configured_filter(self):
- """
- An error should be displayed when the filter class is misconfigured.
- """
- view = IncorrectlyConfiguredRootView.as_view()
-
- request = factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- @unittest.skipUnless(django_filters, 'django-filters not installed')
- def test_unknown_filter(self):
- """
- GET requests with filters that aren't configured should return 200.
- """
- view = FilterFieldsRootView.as_view()
-
- search_integer = 10
- request = factory.get('/?integer=%s' % search_integer)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f564890c..15d87e86 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,7 +1,8 @@
from __future__ import unicode_literals
from django.db import models
+from django.shortcuts import get_object_or_404
from django.test import TestCase
-from rest_framework import generics, serializers, status
+from rest_framework import generics, renderers, serializers, status
from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.compat import six
@@ -38,6 +39,7 @@ class SlugBasedInstanceView(InstanceView):
"""
model = SlugBasedModel
serializer_class = SlugSerializer
+ lookup_field = 'slug'
class TestRootView(TestCase):
@@ -302,6 +304,47 @@ class TestInstanceView(TestCase):
self.assertEqual(new_obj.text, 'foobar')
+class TestOverriddenGetObject(TestCase):
+ """
+ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
+ queryset/model mechanism but instead overrides get_object()
+ """
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example detail view for override of get_object().
+ """
+ model = BasicModel
+
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
+
+ self.view = OverriddenGetObjectView.as_view()
+
+ def test_overridden_get_object_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
# Regression test for #285
class CommentSerializer(serializers.ModelSerializer):
@@ -335,7 +378,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
self.assertEqual(created.content, 'foobar')
-# Test for particularly ugly regression with m2m in browseable API
+# Test for particularly ugly regression with m2m in browsable API
class ClassB(models.Model):
name = models.CharField(max_length=255)
@@ -360,7 +403,7 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowseableAPI(TestCase):
def test_m2m_in_browseable_api(self):
"""
- Test for particularly ugly regression with m2m in browseable API
+ Test for particularly ugly regression with m2m in browsable API
"""
request = factory.get('/', HTTP_ACCEPT='text/html')
view = ExampleView().as_view()
@@ -392,22 +435,14 @@ class TestFilterBackendAppliedToViews(TestCase):
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
- self.root_view = RootView.as_view()
- self.instance_view = InstanceView.as_view()
- self.original_root_backend = getattr(RootView, 'filter_backend')
- self.original_instance_backend = getattr(InstanceView, 'filter_backend')
-
- def tearDown(self):
- setattr(RootView, 'filter_backend', self.original_root_backend)
- setattr(InstanceView, 'filter_backend', self.original_instance_backend)
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
GET requests to ListCreateAPIView should return filtered list.
"""
- setattr(RootView, 'filter_backend', InclusiveFilterBackend)
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/')
- response = self.root_view(request).render()
+ response = root_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
@@ -416,9 +451,9 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
- setattr(RootView, 'filter_backend', ExclusiveFilterBackend)
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/')
- response = self.root_view(request).render()
+ response = root_view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, [])
@@ -426,9 +461,9 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
- setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend)
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1')
- response = self.instance_view(request, pk=1).render()
+ response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {'detail': 'Not found'})
@@ -436,8 +471,40 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
- setattr(InstanceView, 'filter_backend', InclusiveFilterBackend)
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1')
- response = self.instance_view(request, pk=1).render()
+ response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
+
+
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ model = TwoFieldModel
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ return DynamicSerializer
+ return super(DynamicSerializerView, self).get_serializer_class()
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def test_dynamic_serializer_form_in_browsable_api(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ view = DynamicSerializerView.as_view()
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertContains(response, 'field_b')
+ self.assertNotContains(response, 'field_a')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 9a61f299..8fc6ba77 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -27,6 +27,14 @@ class PhotoSerializer(serializers.Serializer):
return Photo(**attrs)
+class AlbumSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
+
+ class Meta:
+ model = Album
+ fields = ('title', 'url')
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -73,6 +81,8 @@ class PhotoListCreate(generics.ListCreateAPIView):
class AlbumDetail(generics.RetrieveAPIView):
model = Album
+ serializer_class = AlbumSerializer
+ lookup_field = 'title'
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
@@ -180,6 +190,36 @@ class TestManyToManyHyperlinkedView(TestCase):
self.assertEqual(response.data, self.data[0])
+class TestHyperlinkedIdentityFieldLookup(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 Album instances.
+ """
+ titles = ['foo', 'bar', 'baz']
+ for title in titles:
+ album = Album(title=title)
+ album.save()
+ self.detail_view = AlbumDetail.as_view()
+ self.data = {
+ 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
+ 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
+ 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
+ }
+
+ def test_lookup_field(self):
+ """
+ GET requests to AlbumDetail view should return serialized Albums
+ with a url field keyed by `title`.
+ """
+ for album in Album.objects.all():
+ request = factory.get('/albums/{0}/'.format(album.title))
+ response = self.detail_view(request, title=album.title)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[album.title])
+
+
class TestCreateWithForeignKeys(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index f2117538..40e41a64 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -58,13 +58,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
-# Model to test filtering.
-class FilterableItem(RESTFrameworkModel):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
# Model for regression test for #285
class Comment(RESTFrameworkModel):
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index d2c9b051..e538a78e 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,18 +1,24 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
-import django
+from django.db import models
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
-from rest_framework.tests.models import BasicModel, FilterableItem
+from rest_framework.tests.models import BasicModel
factory = RequestFactory()
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.
@@ -124,21 +130,11 @@ class IntegrationTestPaginationAndFiltering(TestCase):
model = FilterableItem
paginate_by = 10
filter_class = DecimalFilter
- filter_backend = filters.DjangoFilterBackend
+ filter_backends = (filters.DjangoFilterBackend,)
view = FilterFieldsRootView.as_view()
EXPECTED_NUM_QUERIES = 2
- if django.VERSION < (1, 4):
- # On Django 1.3 we need to use django-filter 0.5.4
- #
- # The filter objects there don't expose a `.count()` method,
- # which means we only make a single query *but* it's a single
- # query across *all* of the queryset, instead of a COUNT and then
- # a SELECT with a LIMIT.
- #
- # Although this is fewer queries, it's actually a regression.
- EXPECTED_NUM_QUERIES = 1
request = factory.get('/?decimal=15.20')
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
@@ -181,7 +177,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
class BasicFilterFieldsRootView(generics.ListCreateAPIView):
model = FilterableItem
paginate_by = 10
- filter_backend = DecimalFilterBackend
+ filter_backends = (DecimalFilterBackend,)
view = BasicFilterFieldsRootView.as_view()
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py
index 539c5b44..7699e10c 100644
--- a/rest_framework/tests/parsers.py
+++ b/rest_framework/tests/parsers.py
@@ -1,10 +1,11 @@
from __future__ import unicode_literals
from rest_framework.compat import StringIO
from django import forms
+from django.core.files.uploadhandler import MemoryFileUploadHandler
from django.test import TestCase
from django.utils import unittest
from rest_framework.compat import etree
-from rest_framework.parsers import FormParser
+from rest_framework.parsers import FormParser, FileUploadParser
from rest_framework.parsers import XMLParser
import datetime
@@ -82,3 +83,33 @@ class TestXMLParser(TestCase):
parser = XMLParser()
data = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)
+
+
+class TestFileUploadParser(TestCase):
+ def setUp(self):
+ class MockRequest(object):
+ pass
+ from io import BytesIO
+ self.stream = BytesIO(
+ "Test text file".encode('utf-8')
+ )
+ request = MockRequest()
+ request.upload_handlers = (MemoryFileUploadHandler(),)
+ request.META = {
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
+ 'HTTP_CONTENT_LENGTH': 14,
+ }
+ self.parser_context = {'request': request, 'kwargs': {}}
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ data_and_files = parser.parse(self.stream, None, self.parser_context)
+ file_obj = data_and_files.files['file']
+ self.assertEqual(file_obj._size, 14)
+
+ def test_get_filename(self):
+ parser = FileUploadParser()
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'file.txt'.encode('utf-8'))
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
index cbf93c65..d19219c9 100644
--- a/rest_framework/tests/relations.py
+++ b/rest_framework/tests/relations.py
@@ -5,6 +5,7 @@ from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import serializers
+from rest_framework.tests.models import BlogPost
class NullModel(models.Model):
@@ -33,7 +34,7 @@ class FieldTests(TestCase):
self.assertRaises(serializers.ValidationError, field.from_native, [])
-class TestManyRelateMixin(TestCase):
+class TestManyRelatedMixin(TestCase):
def test_missing_many_to_many_related_field(self):
'''
Regression test for #632
@@ -45,3 +46,55 @@ class TestManyRelateMixin(TestCase):
into = {}
field.field_from_native({}, None, 'field_name', into)
self.assertEqual(into['field_name'], [])
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class RelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index b5702a48..b3efbf52 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -4,6 +4,7 @@ from django.test.client import RequestFactory
from rest_framework import serializers
from rest_framework.compat import patterns, url
from rest_framework.tests.models import (
+ BlogPost,
ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
)
@@ -16,6 +17,7 @@ def dummy_view(request, pk):
pass
urlpatterns = patterns('',
+ url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
@@ -26,42 +28,44 @@ urlpatterns = patterns('',
)
+# ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
-
class Meta:
model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
-
class Meta:
model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('url', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
- nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
-
class Meta:
model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
@@ -449,3 +453,72 @@ class HyperlinkedNullableOneToOneTests(TestCase):
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
]
self.assertEqual(serializer.data, expected)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class HyperlinkedRelatedFieldSourceTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_manager',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_queryset',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='a.b.c',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index a125ba65..f6d006b3 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
- model = ForeignKeySource
-
-
-class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+ depth = 1
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = FlatForeignKeySourceSerializer(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+ depth = 1
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
model = NullableForeignKeySource
-
-
-class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableOneToOneSource
+ fields = ('id', 'name', 'target')
+ depth = 1
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = NullableOneToOneSourceSerializer()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+ depth = 1
class ReverseForeignKeyTests(TestCase):
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index f08e1808..0f8c5247 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -1,45 +1,51 @@
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+from rest_framework.tests.models import (
+ BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
+)
from rest_framework.compat import six
+# ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = serializers.PrimaryKeyRelatedField()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
@@ -418,3 +424,55 @@ class PKNullableOneToOneTests(TestCase):
{'id': 2, 'name': 'target-2', 'nullable_source': 1},
]
self.assertEqual(serializer.data, expected)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class PrimaryKeyRelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
diff --git a/rest_framework/tests/routers.py b/rest_framework/tests/routers.py
new file mode 100644
index 00000000..4e4765cb
--- /dev/null
+++ b/rest_framework/tests/routers.py
@@ -0,0 +1,55 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import status
+from rest_framework.response import Response
+from rest_framework import viewsets
+from rest_framework.decorators import link, action
+from rest_framework.routers import SimpleRouter
+import copy
+
+factory = RequestFactory()
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @action()
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @link()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}/$'.format(endpoint))
+ # check method to function mapping
+ if endpoint.startswith('action'):
+ method_map = 'post'
+ else:
+ method_map = 'get'
+ self.assertEqual(route.mapping[method_map], endpoint)
+
+
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 0386ca76..d978dc87 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -1,9 +1,11 @@
from __future__ import unicode_literals
+from django.db import models
+from django.db.models.fields import BLANK_CHOICE_DASH
from django.utils.datastructures import MultiValueDict
from django.test import TestCase
from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
- BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
+ BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
import datetime
import pickle
@@ -43,6 +45,17 @@ class CommentSerializer(serializers.Serializer):
return instance
+class NamesSerializer(serializers.Serializer):
+ first = serializers.CharField()
+ last = serializers.CharField(required=False, default='')
+ initials = serializers.CharField(required=False, default='')
+
+
+class PersonIdentifierSerializer(serializers.Serializer):
+ ssn = serializers.CharField()
+ names = NamesSerializer(source='names', required=False)
+
+
class BookSerializer(serializers.ModelSerializer):
isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
@@ -78,6 +91,18 @@ class PersonSerializer(serializers.ModelSerializer):
read_only_fields = ('age',)
+class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
+ """
+ Testing for #652.
+ """
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age', 'info')
+
+
class AlbumsSerializer(serializers.ModelSerializer):
class Meta:
@@ -141,6 +166,42 @@ class BasicTests(TestCase):
self.assertFalse(serializer.object is expected)
self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
+ def test_create_nested(self):
+ """Test a serializer with nested data."""
+ names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ self.assertEqual(serializer.data['names'], names)
+
+ def test_create_partial_nested(self):
+ """Test a serializer with nested data which has missing fields."""
+ names = {'first': 'John'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ expected_names = {'first': 'John', 'last': '', 'initials': ''}
+ data['names'] = expected_names
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is expected_names)
+ self.assertEqual(serializer.data['names'], expected_names)
+
+ def test_null_nested(self):
+ """Test a serializer with a nonexistent nested field"""
+ data = {'ssn': '1234567890'}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ expected = {'ssn': '1234567890', 'names': None}
+ self.assertEqual(serializer.data, expected)
+
def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
@@ -189,6 +250,12 @@ class BasicTests(TestCase):
# Assert age is unchanged (35)
self.assertEqual(instance.age, self.person_data['age'])
+ def test_invalid_read_only_fields(self):
+ """
+ Regression test for #652.
+ """
+ self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+
class DictStyleSerializer(serializers.Serializer):
"""
@@ -357,7 +424,6 @@ class CustomValidationTests(TestCase):
def validate_email(self, attrs, source):
value = attrs[source]
-
return attrs
def validate_content(self, attrs, source):
@@ -738,6 +804,43 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post")
@@ -767,8 +870,6 @@ class RelatedTraversalTest(TestCase):
post = BlogPost.objects.create(title="Test blog post", writer=user)
post.blogpostcomment_set.create(text="I love this blog post")
- from rest_framework.tests.models import BlogPostComment
-
class PersonSerializer(serializers.ModelSerializer):
class Meta:
model = Person
@@ -819,23 +920,6 @@ class RelatedTraversalTest(TestCase):
self.assertEqual(serializer.data, expected)
- def test_queryset_nested_traversal(self):
- """
- Relational fields should be able to use methods as their source.
- """
- BlogPost.objects.create(title='blah')
-
- class QuerysetMethodSerializer(serializers.Serializer):
- blogposts = serializers.RelatedField(many=True, source='get_all_blogposts')
-
- class ClassWithQuerysetMethod(object):
- def get_all_blogposts(self):
- return BlogPost.objects
-
- obj = ClassWithQuerysetMethod()
- serializer = QuerysetMethodSerializer(obj)
- self.assertEqual(serializer.data, {'blogposts': ['BlogPost object']})
-
class SerializerMethodFieldTests(TestCase):
def setUp(self):
@@ -966,25 +1050,95 @@ class SerializerPickleTests(TestCase):
repr(pickle.loads(pickle.dumps(data, 0)))
+# test for issue #725
+class SeveralChoicesModel(models.Model):
+ color = models.CharField(
+ max_length=10,
+ choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')],
+ blank=False
+ )
+ drink = models.CharField(
+ max_length=10,
+ choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')],
+ blank=False,
+ default='beer'
+ )
+ os = models.CharField(
+ max_length=10,
+ choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')],
+ blank=True
+ )
+ music_genre = models.CharField(
+ max_length=10,
+ choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')],
+ blank=True,
+ default='metal'
+ )
+
+
+class SerializerChoiceFields(TestCase):
+
+ def setUp(self):
+ super(SerializerChoiceFields, self).setUp()
+
+ class SeveralChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SeveralChoicesModel
+ fields = ('color', 'drink', 'os', 'music_genre')
+
+ self.several_choices_serializer = SeveralChoicesSerializer
+
+ def test_choices_blank_false_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['color'].choices,
+ [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')]
+ )
+
+ def test_choices_blank_false_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['drink'].choices,
+ [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')]
+ )
+
+ def test_choices_blank_true_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['os'].choices,
+ BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')]
+ )
+
+ def test_choices_blank_true_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['music_genre'].choices,
+ BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')]
+ )
+
+
class DepthTest(TestCase):
def test_implicit_nesting(self):
+
writer = Person.objects.create(name="django", age=1)
post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
- class BlogPostSerializer(serializers.ModelSerializer):
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
class Meta:
- model = BlogPost
- depth = 1
+ model = BlogPostComment
+ depth = 2
- serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
self.assertEqual(serializer.data, expected)
def test_explicit_nesting(self):
writer = Person.objects.create(name="django", age=1)
post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
class PersonSerializer(serializers.ModelSerializer):
class Meta:
@@ -996,9 +1150,15 @@ class DepthTest(TestCase):
class Meta:
model = BlogPost
- serializer = BlogPostSerializer(instance=post)
- expected = {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ blog_post = BlogPostSerializer()
+
+ class Meta:
+ model = BlogPostComment
+
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
self.assertEqual(serializer.data, expected)
@@ -1066,7 +1226,7 @@ class DeserializeListTestCase(TestCase):
def test_no_errors(self):
data = [self.data.copy() for x in range(0, 3)]
- serializer = CommentSerializer(data=data)
+ serializer = CommentSerializer(data=data, many=True)
self.assertTrue(serializer.is_valid())
self.assertTrue(isinstance(serializer.object, list))
self.assertTrue(
@@ -1078,7 +1238,7 @@ class DeserializeListTestCase(TestCase):
invalid_item['email'] = ''
data = [self.data.copy(), invalid_item, self.data.copy()]
- serializer = CommentSerializer(data=data)
+ serializer = CommentSerializer(data=data, many=True)
self.assertFalse(serializer.is_valid())
expected = [{}, {'email': ['This field is required.']}, {}]
self.assertEqual(serializer.errors, expected)
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py
index afc1a1a9..8b0ded1a 100644
--- a/rest_framework/tests/serializer_bulk_update.py
+++ b/rest_framework/tests/serializer_bulk_update.py
@@ -98,7 +98,7 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
- expected_errors = {'non_field_errors': ['Expected a list of items']}
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
self.assertEqual(serializer.errors, expected_errors)
@@ -115,7 +115,7 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
- expected_errors = {'non_field_errors': ['Expected a list of items']}
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
self.assertEqual(serializer.errors, expected_errors)
@@ -201,11 +201,12 @@ class BulkUpdateSerializerTests(TestCase):
'author': 'Haruki Murakami'
}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data)
serializer.save()
new_data = self.BookSerializer(self.books(), many=True).data
+
self.assertEqual(data, new_data)
def test_bulk_update_and_create(self):
@@ -223,13 +224,36 @@ class BulkUpdateSerializerTests(TestCase):
'author': 'Haruki Murakami'
}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data)
serializer.save()
new_data = self.BookSerializer(self.books(), many=True).data
self.assertEqual(data, new_data)
+ def test_bulk_update_invalid_create(self):
+ """
+ Bulk update serialization without allow_add_remove may not create items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
def test_bulk_update_error(self):
"""
Incorrect bulk update serialization should return error data.
@@ -249,6 +273,6 @@ class BulkUpdateSerializerTests(TestCase):
{},
{'id': ['Enter a whole number.']}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors)
diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py
index 6a29c652..71d0e24b 100644
--- a/rest_framework/tests/serializer_nested.py
+++ b/rest_framework/tests/serializer_nested.py
@@ -109,7 +109,7 @@ class WritableNestedSerializerBasicTests(TestCase):
}
]
- serializer = self.AlbumSerializer(data=data)
+ serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors)
@@ -241,6 +241,6 @@ class WritableNestedSerializerObjectTests(TestCase):
)
]
- serializer = self.AlbumSerializer(data=data)
+ serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected_object)