diff options
Diffstat (limited to 'rest_framework/tests')
| -rw-r--r-- | rest_framework/tests/test_description.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/test_fields.py | 17 | ||||
| -rw-r--r-- | rest_framework/tests/test_files.py | 37 | ||||
| -rw-r--r-- | rest_framework/tests/test_generics.py | 53 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py | 47 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_nested.py | 351 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_pk.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/test_testing.py | 30 | ||||
| -rw-r--r-- | rest_framework/tests/test_views.py | 41 |
10 files changed, 522 insertions, 74 deletions
diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py index 8019f5ec..4c03c1de 100644 --- a/rest_framework/tests/test_description.py +++ b/rest_framework/tests/test_description.py @@ -6,7 +6,6 @@ from rest_framework.compat import apply_markdown, smart_text from rest_framework.views import APIView from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring from rest_framework.tests.description import UTF8_TEST_DOCSTRING -from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. DESCRIPTION = """an example docstring @@ -58,7 +57,7 @@ class TestViewNamesAndDescriptions(TestCase): """ class MockView(APIView): pass - self.assertEqual(get_view_name(MockView), 'Mock') + self.assertEqual(MockView().get_view_name(), 'Mock') def test_view_description_uses_docstring(self): """Ensure view descriptions are based on the docstring.""" @@ -78,7 +77,7 @@ class TestViewNamesAndDescriptions(TestCase): # hash style header #""" - self.assertEqual(get_view_description(MockView), DESCRIPTION) + self.assertEqual(MockView().get_view_description(), DESCRIPTION) def test_view_description_supports_unicode(self): """ @@ -86,7 +85,7 @@ class TestViewNamesAndDescriptions(TestCase): """ self.assertEqual( - get_view_description(ViewWithNonASCIICharactersInDocstring), + ViewWithNonASCIICharactersInDocstring().get_view_description(), smart_text(UTF8_TEST_DOCSTRING) ) @@ -97,7 +96,7 @@ class TestViewNamesAndDescriptions(TestCase): """ class MockView(APIView): pass - self.assertEqual(get_view_description(MockView), '') + self.assertEqual(MockView().get_view_description(), '') def test_markdown(self): """ diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 6836ec86..34fbab9c 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase): f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES) + result = f.from_native('') + self.assertEqual(result, None) + class EmailFieldTests(TestCase): """ @@ -896,3 +904,12 @@ class CustomIntegerField(TestCase): self.assertFalse(serializer.is_valid()) +class BooleanField(TestCase): + """ + Tests for BooleanField + """ + def test_boolean_required(self): + class BooleanRequiredSerializer(serializers.Serializer): + bool_field = serializers.BooleanField(required=True) + + self.assertFalse(BooleanRequiredSerializer(data={}).is_valid()) diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py index 487046ac..c13c38b8 100644 --- a/rest_framework/tests/test_files.py +++ b/rest_framework/tests/test_files.py @@ -7,13 +7,13 @@ import datetime class UploadedFile(object): - def __init__(self, file, created=None): + def __init__(self, file=None, created=None): self.file = file self.created = created or datetime.datetime.now() class UploadedFileSerializer(serializers.Serializer): - file = serializers.FileField() + file = serializers.FileField(required=False) created = serializers.DateTimeField() def restore_object(self, attrs, instance=None): @@ -47,5 +47,36 @@ class FileSerializerTests(TestCase): now = datetime.datetime.now() serializer = UploadedFileSerializer(data={'created': now}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.created, now) + self.assertIsNone(serializer.object.file) + + def test_remove_with_empty_string(self): + """ + Passing empty string as data should cause file to be removed + + Test for: + https://github.com/tomchristie/django-rest-framework/issues/937 + """ + now = datetime.datetime.now() + file = BytesIO(six.b('stuff')) + file.name = 'stuff.txt' + file.size = len(file.getvalue()) + + uploaded_file = UploadedFile(file=file, created=now) + + serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.created, uploaded_file.created) + self.assertIsNone(serializer.object.file) + + def test_validation_error_with_non_file(self): + """ + Passing non-files should raise a validation error. + """ + now = datetime.datetime.now() + errmsg = 'No file was submitted. Check the encoding type on the form.' + + serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) self.assertFalse(serializer.is_valid()) - self.assertIn('file', serializer.errors) + self.assertEqual(serializer.errors, {'file': [errmsg]}) diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 1550880b..79cd99ac 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -272,6 +272,48 @@ class TestInstanceView(TestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data, expected) + def test_options_before_instance_create(self): + """ + OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata + before the instance has been created + """ + request = factory.options('/999') + with self.assertNumQueries(1): + response = self.view(request, pk=999).render() + expected = { + 'parses': [ + 'application/json', + 'application/x-www-form-urlencoded', + 'multipart/form-data' + ], + 'renders': [ + 'application/json', + 'text/html' + ], + 'name': 'Instance', + 'description': 'Example description for OPTIONS.', + 'actions': { + 'PUT': { + 'text': { + 'max_length': 100, + 'read_only': False, + 'required': True, + 'type': 'string', + 'label': 'Text comes here', + 'help_text': 'Text description.' + }, + 'id': { + 'read_only': True, + 'required': False, + 'type': 'integer', + 'label': 'ID', + }, + } + } + } + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, expected) + def test_get_instance_view_incorrect_arg(self): """ GET requests with an incorrect pk type, should raise 404, not 500. @@ -338,6 +380,17 @@ class TestInstanceView(TestCase): new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') + def test_patch_cannot_create_an_object(self): + """ + PATCH requests should not be able to create objects. + """ + data = {'text': 'foobar'} + request = factory.patch('/999', data, format='json') + with self.assertNumQueries(1): + response = self.view(request, pk=999).render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertFalse(self.objects.filter(id=999).exists()) + class TestOverriddenGetObject(TestCase): """ diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index 85d4640e..4170d4b6 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView): paginate_by_param = 'page_size' +class MaxPaginateByView(generics.ListAPIView): + """ + View for testing custom max_paginate_by usage + """ + model = BasicModel + paginate_by = 3 + max_paginate_by = 5 + paginate_by_param = 'page_size' + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase): self.assertEqual(response.data['results'], self.data[:5]) +class TestMaxPaginateByParam(TestCase): + """ + Tests for list views with max_paginate_by kwarg + """ + + def setUp(self): + """ + Create 13 BasicModel instances. + """ + for i in range(13): + BasicModel(text=i).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = MaxPaginateByView.as_view() + + def test_max_paginate_by(self): + """ + If max_paginate_by is set, it should limit page size for the view. + """ + request = factory.get('/?page_size=10') + response = self.view(request).render() + self.assertEqual(response.data['count'], 13) + self.assertEqual(response.data['results'], self.data[:5]) + + def test_max_paginate_by_without_page_size_param(self): + """ + If max_paginate_by is set, but client does not specifiy page_size, + standard `paginate_by` behavior should be used. + """ + request = factory.get('/') + response = self.view(request).render() + self.assertEqual(response.data['results'], self.data[:3]) + + ### Tests for context in pagination serializers class CustomField(serializers.Field): diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index f6d006b3..d393b0c3 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -1,107 +1,328 @@ from __future__ import unicode_literals +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource -class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeySource - fields = ('id', 'name', 'target') - depth = 1 +class OneToOneTarget(models.Model): + name = models.CharField(max_length=100) -class ForeignKeyTargetSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeyTarget - fields = ('id', 'name', 'sources') - depth = 1 +class OneToOneSource(models.Model): + name = models.CharField(max_length=100) + target = models.OneToOneField(OneToOneTarget, related_name='source', + null=True, blank=True) -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = NullableForeignKeySource - fields = ('id', 'name', 'target') - depth = 1 +class OneToManyTarget(models.Model): + name = models.CharField(max_length=100) -class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - class Meta: - model = OneToOneTarget - fields = ('id', 'name', 'nullable_source') - depth = 1 +class OneToManySource(models.Model): + name = models.CharField(max_length=100) + target = models.ForeignKey(OneToManyTarget, related_name='sources') -class ReverseForeignKeyTests(TestCase): +class ReverseNestedOneToOneTests(TestCase): def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - new_target = ForeignKeyTarget(name='target-2') - new_target.save() + class OneToOneSourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneSource + fields = ('id', 'name') + + class OneToOneTargetSerializer(serializers.ModelSerializer): + source = OneToOneSourceSerializer() + + class Meta: + model = OneToOneTarget + fields = ('id', 'name', 'source') + + self.Serializer = OneToOneTargetSerializer + for idx in range(1, 4): - source = ForeignKeySource(name='source-%d' % idx, target=target) + target = OneToOneTarget(name='target-%d' % idx) + target.save() + source = OneToOneSource(name='source-%d' % idx, target=target) source.save() - def test_foreign_key_retrieve(self): - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) + def test_one_to_one_retrieve(self): + queryset = OneToOneTarget.objects.all() + serializer = self.Serializer(queryset, many=True) expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} ] self.assertEqual(serializer.data, expected) - def test_reverse_foreign_key_retrieve(self): - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} + serializer = self.Serializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = self.Serializer(queryset, many=True) expected = [ - {'id': 1, 'name': 'target-1', 'sources': [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': 1}, - ]}, - {'id': 2, 'name': 'target-2', 'sources': [ - ]} + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, + {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} ] self.assertEqual(serializer.data, expected) + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) -class NestedNullableForeignKeyTests(TestCase): + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} + instance = OneToOneTarget.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, + # and everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} + ] + self.assertEqual(serializer.data, expected) + + +class ForwardNestedOneToOneTests(TestCase): def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() + class OneToOneTargetSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneTarget + fields = ('id', 'name') + + class OneToOneSourceSerializer(serializers.ModelSerializer): + target = OneToOneTargetSerializer() + + class Meta: + model = OneToOneSource + fields = ('id', 'name', 'target') + + self.Serializer = OneToOneSourceSerializer + for idx in range(1, 4): - if idx == 3: - target = None - source = NullableForeignKeySource(name='source-%d' % idx, target=target) + target = OneToOneTarget(name='target-%d' % idx) + target.save() + source = OneToOneSource(name='source-%d' % idx, target=target) source.save() - def test_foreign_key_retrieve_with_null(self): - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) + def test_one_to_one_retrieve(self): + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} + serializer = self.Serializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, + {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) + + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} + instance = OneToOneSource.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, + # and everything else is as expected. + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} ] self.assertEqual(serializer.data, expected) + def test_one_to_one_update_to_null(self): + data = {'id': 3, 'name': 'source-3-updated', 'target': None} + instance = OneToOneSource.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() -class NestedNullableOneToOneTests(TestCase): + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-3-updated') + self.assertEqual(obj.target, None) + + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3-updated', 'target': None} + ] + self.assertEqual(serializer.data, expected) + + # TODO: Nullable 1-1 tests + # def test_one_to_one_delete(self): + # data = {'id': 3, 'name': 'target-3', 'target_source': None} + # instance = OneToOneTarget.objects.get(pk=3) + # serializer = self.Serializer(instance, data=data) + # self.assertTrue(serializer.is_valid()) + # serializer.save() + + # # Ensure (target_source 3, source 3) are deleted, + # # and everything else is as expected. + # queryset = OneToOneTarget.objects.all() + # serializer = self.Serializer(queryset) + # expected = [ + # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + # {'id': 3, 'name': 'target-3', 'source': None} + # ] + # self.assertEqual(serializer.data, expected) + + +class ReverseNestedOneToManyTests(TestCase): def setUp(self): - target = OneToOneTarget(name='target-1') + class OneToManySourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToManySource + fields = ('id', 'name') + + class OneToManyTargetSerializer(serializers.ModelSerializer): + sources = OneToManySourceSerializer(many=True, allow_add_remove=True) + + class Meta: + model = OneToManyTarget + fields = ('id', 'name', 'sources') + + self.Serializer = OneToManyTargetSerializer + + target = OneToManyTarget(name='target-1') target.save() - new_target = OneToOneTarget(name='target-2') - new_target.save() - source = NullableOneToOneSource(name='source-1', target=target) - source.save() + for idx in range(1, 4): + source = OneToManySource(name='source-%d' % idx, target=target) + source.save() - def test_reverse_foreign_key_retrieve_with_null(self): - queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset, many=True) + def test_one_to_many_retrieve(self): + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]}, + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1') + + # Ensure source 4 is added, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset, many=True) expected = [ - {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, - {'id': 2, 'name': 'target-2', 'nullable_source': None}, + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create_with_invalid_data(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4}]} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) + + def test_one_to_many_update(self): + data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1-updated') + + # Ensure (target 1, source 1) are updated, + # and everything else is as expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_delete(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + + # Ensure source 2 is deleted, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 3, 'name': 'source-3'}]} + ] self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/test_relations_pk.py b/rest_framework/tests/test_relations_pk.py index e2a1b815..3815afdd 100644 --- a/rest_framework/tests/test_relations_pk.py +++ b/rest_framework/tests/test_relations_pk.py @@ -283,6 +283,15 @@ class PKForeignKeyTests(TestCase): self.assertFalse(serializer.is_valid()) self.assertEqual(serializer.errors, {'target': ['This field is required.']}) + def test_foreign_key_with_empty(self): + """ + Regression test for #1072 + + https://github.com/tomchristie/django-rest-framework/issues/1072 + """ + serializer = NullableForeignKeySourceSerializer() + self.assertEqual(serializer.data['target'], None) + class PKNullableForeignKeyTests(TestCase): def setUp(self): diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index 5fcccb74..e723f7d4 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase): self.urls = self.router.urls def test_urls_can_have_trailing_slash_removed(self): - expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] + expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$'] for idx in range(len(expected)): self.assertEqual(expected[idx], self.urls[idx].regex.pattern) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 49d45fc2..48b8956b 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -17,8 +17,18 @@ def view(request): }) +@api_view(['GET', 'POST']) +def session_view(request): + active_session = request.session.get('active_session', False) + request.session['active_session'] = True + return Response({ + 'active_session': active_session + }) + + urlpatterns = patterns('', url(r'^view/$', view), + url(r'^session-view/$', session_view), ) @@ -46,6 +56,26 @@ class TestAPITestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['user'], 'example') + def test_force_authenticate_with_sessions(self): + """ + Setting `.force_authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.force_authenticate(user) + + # First request does not yet have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + + # Subsequant requests have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], True) + + # Force authenticating as `None` should also logout the user session. + self.client.force_authenticate(None) + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + def test_csrf_exempt_by_default(self): """ By default, the test client is CSRF exempt. diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py index c0bec5ae..65c7e50e 100644 --- a/rest_framework/tests/test_views.py +++ b/rest_framework/tests/test_views.py @@ -32,6 +32,16 @@ def basic_view(request): return {'method': 'PATCH', 'data': request.DATA} +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + raise Exception + + +@api_view(['GET']) +def error_view(request): + raise Exception + + def sanitise_json_error(error_dict): """ Exact contents of JSON error messages depend on the installed version @@ -99,3 +109,34 @@ class FunctionBasedViewIntegrationTests(TestCase): } self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(sanitise_json_error(response.data), expected) + + +class TestCustomExceptionHandler(TestCase): + def setUp(self): + self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + + def exception_handler(exc): + return Response('Error!', status=status.HTTP_400_BAD_REQUEST) + + api_settings.EXCEPTION_HANDLER = exception_handler + + def tearDown(self): + api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + + def test_class_based_view_exception_handler(self): + view = ErrorView.as_view() + + request = factory.get('/', content_type='application/json') + response = view(request) + expected = 'Error!' + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, expected) + + def test_function_based_view_exception_handler(self): + view = error_view + + request = factory.get('/', content_type='application/json') + response = view(request) + expected = 'Error!' + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, expected) |
