aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
authorPhilip Douglas2013-09-10 13:09:25 +0100
committerPhilip Douglas2013-09-10 13:09:25 +0100
commit39e13a0d1341c0a0e694acb1522a99470c4037be (patch)
tree27b498f3cbf81faa1ff587d0730e07706c7551a8 /rest_framework/tests
parentef7ce344865938bea285a408a7cc415a7b90a83c (diff)
parentf5c34926d6a4b4b29fb083d25b99b10d7431eee4 (diff)
downloaddjango-rest-framework-39e13a0d1341c0a0e694acb1522a99470c4037be.tar.bz2
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/test_description.py9
-rw-r--r--rest_framework/tests/test_fields.py17
-rw-r--r--rest_framework/tests/test_files.py37
-rw-r--r--rest_framework/tests/test_generics.py53
-rw-r--r--rest_framework/tests/test_pagination.py47
-rw-r--r--rest_framework/tests/test_relations_nested.py351
-rw-r--r--rest_framework/tests/test_relations_pk.py9
-rw-r--r--rest_framework/tests/test_routers.py2
-rw-r--r--rest_framework/tests/test_testing.py30
-rw-r--r--rest_framework/tests/test_views.py41
10 files changed, 522 insertions, 74 deletions
diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py
index 8019f5ec..4c03c1de 100644
--- a/rest_framework/tests/test_description.py
+++ b/rest_framework/tests/test_description.py
@@ -6,7 +6,6 @@ from rest_framework.compat import apply_markdown, smart_text
from rest_framework.views import APIView
from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring
from rest_framework.tests.description import UTF8_TEST_DOCSTRING
-from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
@@ -58,7 +57,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
class MockView(APIView):
pass
- self.assertEqual(get_view_name(MockView), 'Mock')
+ self.assertEqual(MockView().get_view_name(), 'Mock')
def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring."""
@@ -78,7 +77,7 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEqual(get_view_description(MockView), DESCRIPTION)
+ self.assertEqual(MockView().get_view_description(), DESCRIPTION)
def test_view_description_supports_unicode(self):
"""
@@ -86,7 +85,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
self.assertEqual(
- get_view_description(ViewWithNonASCIICharactersInDocstring),
+ ViewWithNonASCIICharactersInDocstring().get_view_description(),
smart_text(UTF8_TEST_DOCSTRING)
)
@@ -97,7 +96,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
class MockView(APIView):
pass
- self.assertEqual(get_view_description(MockView), '')
+ self.assertEqual(MockView().get_view_description(), '')
def test_markdown(self):
"""
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index 6836ec86..34fbab9c 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase):
f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES)
+ result = f.from_native('')
+ self.assertEqual(result, None)
+
class EmailFieldTests(TestCase):
"""
@@ -896,3 +904,12 @@ class CustomIntegerField(TestCase):
self.assertFalse(serializer.is_valid())
+class BooleanField(TestCase):
+ """
+ Tests for BooleanField
+ """
+ def test_boolean_required(self):
+ class BooleanRequiredSerializer(serializers.Serializer):
+ bool_field = serializers.BooleanField(required=True)
+
+ self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py
index 487046ac..c13c38b8 100644
--- a/rest_framework/tests/test_files.py
+++ b/rest_framework/tests/test_files.py
@@ -7,13 +7,13 @@ import datetime
class UploadedFile(object):
- def __init__(self, file, created=None):
+ def __init__(self, file=None, created=None):
self.file = file
self.created = created or datetime.datetime.now()
class UploadedFileSerializer(serializers.Serializer):
- file = serializers.FileField()
+ file = serializers.FileField(required=False)
created = serializers.DateTimeField()
def restore_object(self, attrs, instance=None):
@@ -47,5 +47,36 @@ class FileSerializerTests(TestCase):
now = datetime.datetime.now()
serializer = UploadedFileSerializer(data={'created': now})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, now)
+ self.assertIsNone(serializer.object.file)
+
+ def test_remove_with_empty_string(self):
+ """
+ Passing empty string as data should cause file to be removed
+
+ Test for:
+ https://github.com/tomchristie/django-rest-framework/issues/937
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertIsNone(serializer.object.file)
+
+ def test_validation_error_with_non_file(self):
+ """
+ Passing non-files should raise a validation error.
+ """
+ now = datetime.datetime.now()
+ errmsg = 'No file was submitted. Check the encoding type on the form.'
+
+ serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid())
- self.assertIn('file', serializer.errors)
+ self.assertEqual(serializer.errors, {'file': [errmsg]})
diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py
index 1550880b..79cd99ac 100644
--- a/rest_framework/tests/test_generics.py
+++ b/rest_framework/tests/test_generics.py
@@ -272,6 +272,48 @@ class TestInstanceView(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
+ def test_options_before_instance_create(self):
+ """
+ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
+ before the instance has been created
+ """
+ request = factory.options('/999')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Instance',
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'PUT': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ 'label': 'Text comes here',
+ 'help_text': 'Text description.'
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
+ }
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
+
def test_get_instance_view_incorrect_arg(self):
"""
GET requests with an incorrect pk type, should raise 404, not 500.
@@ -338,6 +380,17 @@ class TestInstanceView(TestCase):
new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar')
+ def test_patch_cannot_create_an_object(self):
+ """
+ PATCH requests should not be able to create objects.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/999', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertFalse(self.objects.filter(id=999).exists())
+
class TestOverriddenGetObject(TestCase):
"""
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
index 85d4640e..4170d4b6 100644
--- a/rest_framework/tests/test_pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView):
paginate_by_param = 'page_size'
+class MaxPaginateByView(generics.ListAPIView):
+ """
+ View for testing custom max_paginate_by usage
+ """
+ model = BasicModel
+ paginate_by = 3
+ max_paginate_by = 5
+ paginate_by_param = 'page_size'
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase):
self.assertEqual(response.data['results'], self.data[:5])
+class TestMaxPaginateByParam(TestCase):
+ """
+ Tests for list views with max_paginate_by kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = MaxPaginateByView.as_view()
+
+ def test_max_paginate_by(self):
+ """
+ If max_paginate_by is set, it should limit page size for the view.
+ """
+ request = factory.get('/?page_size=10')
+ response = self.view(request).render()
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+ def test_max_paginate_by_without_page_size_param(self):
+ """
+ If max_paginate_by is set, but client does not specifiy page_size,
+ standard `paginate_by` behavior should be used.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEqual(response.data['results'], self.data[:3])
+
+
### Tests for context in pagination serializers
class CustomField(serializers.Field):
diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py
index f6d006b3..d393b0c3 100644
--- a/rest_framework/tests/test_relations_nested.py
+++ b/rest_framework/tests/test_relations_nested.py
@@ -1,107 +1,328 @@
from __future__ import unicode_literals
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeySource
- fields = ('id', 'name', 'target')
- depth = 1
+class OneToOneTarget(models.Model):
+ name = models.CharField(max_length=100)
-class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeyTarget
- fields = ('id', 'name', 'sources')
- depth = 1
+class OneToOneSource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, related_name='source',
+ null=True, blank=True)
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
- fields = ('id', 'name', 'target')
- depth = 1
+class OneToManyTarget(models.Model):
+ name = models.CharField(max_length=100)
-class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name', 'nullable_source')
- depth = 1
+class OneToManySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(OneToManyTarget, related_name='sources')
-class ReverseForeignKeyTests(TestCase):
+class ReverseNestedOneToOneTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name')
+
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ source = OneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'source')
+
+ self.Serializer = OneToOneTargetSerializer
+
for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
source.save()
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
]
self.assertEqual(serializer.data, expected)
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1},
- ]},
- {'id': 2, 'name': 'target-2', 'sources': [
- ]}
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
+ {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
]
self.assertEqual(serializer.data, expected)
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
-class NestedNullableForeignKeyTests(TestCase):
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ instance = OneToOneTarget.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class ForwardNestedOneToOneTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name')
+
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ target = OneToOneTargetSerializer()
+
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name', 'target')
+
+ self.Serializer = OneToOneSourceSerializer
+
for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
source.save()
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
+ {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
+
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
]
self.assertEqual(serializer.data, expected)
+ def test_one_to_one_update_to_null(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': None}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
-class NestedNullableOneToOneTests(TestCase):
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+ self.assertEqual(obj.target, None)
+
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # TODO: Nullable 1-1 tests
+ # def test_one_to_one_delete(self):
+ # data = {'id': 3, 'name': 'target-3', 'target_source': None}
+ # instance = OneToOneTarget.objects.get(pk=3)
+ # serializer = self.Serializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # serializer.save()
+
+ # # Ensure (target_source 3, source 3) are deleted,
+ # # and everything else is as expected.
+ # queryset = OneToOneTarget.objects.all()
+ # serializer = self.Serializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ # {'id': 3, 'name': 'target-3', 'source': None}
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class ReverseNestedOneToManyTests(TestCase):
def setUp(self):
- target = OneToOneTarget(name='target-1')
+ class OneToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToManySource
+ fields = ('id', 'name')
+
+ class OneToManyTargetSerializer(serializers.ModelSerializer):
+ sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
+
+ class Meta:
+ model = OneToManyTarget
+ fields = ('id', 'name', 'sources')
+
+ self.Serializer = OneToManyTargetSerializer
+
+ target = OneToManyTarget(name='target-1')
target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
- source.save()
+ for idx in range(1, 4):
+ source = OneToManySource(name='source-%d' % idx, target=target)
+ source.save()
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True)
+ def test_one_to_many_retrieve(self):
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1')
+
+ # Ensure source 4 is added, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}},
- {'id': 2, 'name': 'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create_with_invalid_data(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4}]}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
+
+ def test_one_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1-updated')
+
+ # Ensure (target 1, source 1) are updated,
+ # and everything else is as expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_delete(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+
+ # Ensure source 2 is deleted, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+
]
self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_relations_pk.py b/rest_framework/tests/test_relations_pk.py
index e2a1b815..3815afdd 100644
--- a/rest_framework/tests/test_relations_pk.py
+++ b/rest_framework/tests/test_relations_pk.py
@@ -283,6 +283,15 @@ class PKForeignKeyTests(TestCase):
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+ def test_foreign_key_with_empty(self):
+ """
+ Regression test for #1072
+
+ https://github.com/tomchristie/django-rest-framework/issues/1072
+ """
+ serializer = NullableForeignKeySourceSerializer()
+ self.assertEqual(serializer.data['target'], None)
+
class PKNullableForeignKeyTests(TestCase):
def setUp(self):
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 5fcccb74..e723f7d4 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase):
self.urls = self.router.urls
def test_urls_can_have_trailing_slash_removed(self):
- expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
+ expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 49d45fc2..48b8956b 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -17,8 +17,18 @@ def view(request):
})
+@api_view(['GET', 'POST'])
+def session_view(request):
+ active_session = request.session.get('active_session', False)
+ request.session['active_session'] = True
+ return Response({
+ 'active_session': active_session
+ })
+
+
urlpatterns = patterns('',
url(r'^view/$', view),
+ url(r'^session-view/$', session_view),
)
@@ -46,6 +56,26 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/')
self.assertEqual(response.data['user'], 'example')
+ def test_force_authenticate_with_sessions(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+
+ # First request does not yet have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ # Subsequant requests have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], True)
+
+ # Force authenticating as `None` should also logout the user session.
+ self.client.force_authenticate(None)
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py
index c0bec5ae..65c7e50e 100644
--- a/rest_framework/tests/test_views.py
+++ b/rest_framework/tests/test_views.py
@@ -32,6 +32,16 @@ def basic_view(request):
return {'method': 'PATCH', 'data': request.DATA}
+class ErrorView(APIView):
+ def get(self, request, *args, **kwargs):
+ raise Exception
+
+
+@api_view(['GET'])
+def error_view(request):
+ raise Exception
+
+
def sanitise_json_error(error_dict):
"""
Exact contents of JSON error messages depend on the installed version
@@ -99,3 +109,34 @@ class FunctionBasedViewIntegrationTests(TestCase):
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
+
+
+class TestCustomExceptionHandler(TestCase):
+ def setUp(self):
+ self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
+
+ def exception_handler(exc):
+ return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
+
+ api_settings.EXCEPTION_HANDLER = exception_handler
+
+ def tearDown(self):
+ api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
+
+ def test_class_based_view_exception_handler(self):
+ view = ErrorView.as_view()
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
+
+ def test_function_based_view_exception_handler(self):
+ view = error_view
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)