aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorTom Christie2015-02-06 14:35:06 +0000
committerTom Christie2015-02-06 14:35:06 +0000
commit3dff9a4fe2952cf632ca7f4cd9ecf4221059ca91 (patch)
tree0649d42b20b875e97cb551b987644b61e7860e84 /tests
parentc06a82d0531f4cb290baacee196829c770913eaa (diff)
parent1f996128458570a909d13f15c3d739fb12111984 (diff)
downloaddjango-rest-framework-model-serializer-caching.tar.bz2
Resolve merge conflictmodel-serializer-caching
Diffstat (limited to 'tests')
-rw-r--r--tests/browsable_api/auth_urls.py1
-rw-r--r--tests/test_fields.py129
-rw-r--r--tests/test_filters.py17
-rw-r--r--tests/test_generics.py8
-rw-r--r--tests/test_htmlrenderer.py6
-rw-r--r--tests/test_metadata.py60
-rw-r--r--tests/test_model_serializer.py54
-rw-r--r--tests/test_multitable_inheritance.py4
-rw-r--r--tests/test_pagination.py1048
-rw-r--r--tests/test_parsers.py4
-rw-r--r--tests/test_relations.py35
-rw-r--r--tests/test_relations_hyperlink.py7
-rw-r--r--tests/test_renderers.py107
-rw-r--r--tests/test_routers.py151
-rw-r--r--tests/test_serializer.py79
-rw-r--r--tests/test_serializer_bulk_update.py4
-rw-r--r--tests/test_versioning.py62
-rw-r--r--tests/utils.py35
18 files changed, 1115 insertions, 696 deletions
diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py
index bce7dcf9..97bc1036 100644
--- a/tests/browsable_api/auth_urls.py
+++ b/tests/browsable_api/auth_urls.py
@@ -3,6 +3,7 @@ from django.conf.urls import patterns, url, include
from .views import MockView
+
urlpatterns = patterns(
'',
(r'^$', MockView.as_view()),
diff --git a/tests/test_fields.py b/tests/test_fields.py
index 04c721d3..48ada780 100644
--- a/tests/test_fields.py
+++ b/tests/test_fields.py
@@ -4,6 +4,7 @@ from rest_framework import serializers
import datetime
import django
import pytest
+import uuid
# Tests for field keyword arguments and core functionality.
@@ -223,8 +224,8 @@ class MockHTMLDict(dict):
getlist = None
-class TestCharHTMLInput:
- def test_empty_html_checkbox(self):
+class TestHTMLInput:
+ def test_empty_html_charfield(self):
class TestSerializer(serializers.Serializer):
message = serializers.CharField(default='happy')
@@ -232,23 +233,31 @@ class TestCharHTMLInput:
assert serializer.is_valid()
assert serializer.validated_data == {'message': 'happy'}
- def test_empty_html_checkbox_allow_null(self):
+ def test_empty_html_charfield_allow_null(self):
class TestSerializer(serializers.Serializer):
message = serializers.CharField(allow_null=True)
- serializer = TestSerializer(data=MockHTMLDict())
+ serializer = TestSerializer(data=MockHTMLDict({'message': ''}))
assert serializer.is_valid()
assert serializer.validated_data == {'message': None}
- def test_empty_html_checkbox_allow_null_allow_blank(self):
+ def test_empty_html_datefield_allow_null(self):
+ class TestSerializer(serializers.Serializer):
+ expiry = serializers.DateField(allow_null=True)
+
+ serializer = TestSerializer(data=MockHTMLDict({'expiry': ''}))
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'expiry': None}
+
+ def test_empty_html_charfield_allow_null_allow_blank(self):
class TestSerializer(serializers.Serializer):
message = serializers.CharField(allow_null=True, allow_blank=True)
- serializer = TestSerializer(data=MockHTMLDict({}))
+ serializer = TestSerializer(data=MockHTMLDict({'message': ''}))
assert serializer.is_valid()
assert serializer.validated_data == {'message': ''}
- def test_empty_html_required_false(self):
+ def test_empty_html_charfield_required_false(self):
class TestSerializer(serializers.Serializer):
message = serializers.CharField(required=False)
@@ -338,7 +347,7 @@ class TestBooleanField(FieldValues):
False: False,
}
invalid_inputs = {
- 'foo': ['`foo` is not a valid boolean.'],
+ 'foo': ['"foo" is not a valid boolean.'],
None: ['This field may not be null.']
}
outputs = {
@@ -368,7 +377,7 @@ class TestNullBooleanField(FieldValues):
None: None
}
invalid_inputs = {
- 'foo': ['`foo` is not a valid boolean.'],
+ 'foo': ['"foo" is not a valid boolean.'],
}
outputs = {
'true': True,
@@ -439,7 +448,7 @@ class TestSlugField(FieldValues):
'slug-99': 'slug-99',
}
invalid_inputs = {
- 'slug 99': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]
+ 'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.']
}
outputs = {}
field = serializers.SlugField()
@@ -459,6 +468,23 @@ class TestURLField(FieldValues):
field = serializers.URLField()
+class TestUUIDField(FieldValues):
+ """
+ Valid and invalid values for `UUIDField`.
+ """
+ valid_inputs = {
+ '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'),
+ '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda')
+ }
+ invalid_inputs = {
+ '825d7aeb-05a9-45b5-a5b7': ['"825d7aeb-05a9-45b5-a5b7" is not a valid UUID.']
+ }
+ outputs = {
+ uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'): '825d7aeb-05a9-45b5-a5b7-05df87923cda'
+ }
+ field = serializers.UUIDField()
+
+
# Number types...
class TestIntegerField(FieldValues):
@@ -640,8 +666,8 @@ class TestDateField(FieldValues):
datetime.date(2001, 1, 1): datetime.date(2001, 1, 1),
}
invalid_inputs = {
- 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'],
- '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'],
+ 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'],
+ '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'],
datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'],
}
outputs = {
@@ -658,7 +684,7 @@ class TestCustomInputFormatDateField(FieldValues):
'1 Jan 2001': datetime.date(2001, 1, 1),
}
invalid_inputs = {
- '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY']
+ '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.']
}
outputs = {}
field = serializers.DateField(input_formats=['%d %b %Y'])
@@ -702,8 +728,8 @@ class TestDateTimeField(FieldValues):
'2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC())
}
invalid_inputs = {
- 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'],
- '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'],
+ 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
+ '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'],
}
outputs = {
@@ -721,7 +747,7 @@ class TestCustomInputFormatDateTimeField(FieldValues):
'1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()),
}
invalid_inputs = {
- '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY']
+ '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.']
}
outputs = {}
field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y'])
@@ -773,8 +799,8 @@ class TestTimeField(FieldValues):
datetime.time(13, 00): datetime.time(13, 00),
}
invalid_inputs = {
- 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'],
- '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'],
+ 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],
+ '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],
}
outputs = {
datetime.time(13, 00): '13:00:00'
@@ -790,7 +816,7 @@ class TestCustomInputFormatTimeField(FieldValues):
'1:00pm': datetime.time(13, 00),
}
invalid_inputs = {
- '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]'],
+ '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'],
}
outputs = {}
field = serializers.TimeField(input_formats=['%I:%M%p'])
@@ -832,7 +858,7 @@ class TestChoiceField(FieldValues):
'good': 'good',
}
invalid_inputs = {
- 'amazing': ['`amazing` is not a valid choice.']
+ 'amazing': ['"amazing" is not a valid choice.']
}
outputs = {
'good': 'good',
@@ -872,8 +898,8 @@ class TestChoiceFieldWithType(FieldValues):
3: 3,
}
invalid_inputs = {
- 5: ['`5` is not a valid choice.'],
- 'abc': ['`abc` is not a valid choice.']
+ 5: ['"5" is not a valid choice.'],
+ 'abc': ['"abc" is not a valid choice.']
}
outputs = {
'1': 1,
@@ -899,7 +925,7 @@ class TestChoiceFieldWithListChoices(FieldValues):
'good': 'good',
}
invalid_inputs = {
- 'awful': ['`awful` is not a valid choice.']
+ 'awful': ['"awful" is not a valid choice.']
}
outputs = {
'good': 'good'
@@ -917,8 +943,8 @@ class TestMultipleChoiceField(FieldValues):
('aircon', 'manual'): set(['aircon', 'manual']),
}
invalid_inputs = {
- 'abc': ['Expected a list of items but got type `str`.'],
- ('aircon', 'incorrect'): ['`incorrect` is not a valid choice.']
+ 'abc': ['Expected a list of items but got type "str".'],
+ ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.']
}
outputs = [
(['aircon', 'manual'], set(['aircon', 'manual']))
@@ -1021,14 +1047,14 @@ class TestValidImageField(FieldValues):
class TestListField(FieldValues):
"""
- Values for `ListField`.
+ Values for `ListField` with IntegerField as child.
"""
valid_inputs = [
([1, 2, 3], [1, 2, 3]),
(['1', '2', '3'], [1, 2, 3])
]
invalid_inputs = [
- ('not a list', ['Expected a list of items but got type `str`']),
+ ('not a list', ['Expected a list of items but got type "str".']),
([1, 2, 'error'], ['A valid integer is required.'])
]
outputs = [
@@ -1038,6 +1064,55 @@ class TestListField(FieldValues):
field = serializers.ListField(child=serializers.IntegerField())
+class TestUnvalidatedListField(FieldValues):
+ """
+ Values for `ListField` with no `child` argument.
+ """
+ valid_inputs = [
+ ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),
+ ]
+ invalid_inputs = [
+ ('not a list', ['Expected a list of items but got type "str".']),
+ ]
+ outputs = [
+ ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),
+ ]
+ field = serializers.ListField()
+
+
+class TestDictField(FieldValues):
+ """
+ Values for `ListField` with CharField as child.
+ """
+ valid_inputs = [
+ ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
+ ]
+ invalid_inputs = [
+ ({'a': 1, 'b': None}, ['This field may not be null.']),
+ ('not a dict', ['Expected a dictionary of items but got type "str".']),
+ ]
+ outputs = [
+ ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
+ ]
+ field = serializers.DictField(child=serializers.CharField())
+
+
+class TestUnvalidatedDictField(FieldValues):
+ """
+ Values for `ListField` with no `child` argument.
+ """
+ valid_inputs = [
+ ({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}),
+ ]
+ invalid_inputs = [
+ ('not a dict', ['Expected a dictionary of items but got type "str".']),
+ ]
+ outputs = [
+ ({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}),
+ ]
+ field = serializers.DictField()
+
+
# Tests for FieldField.
# ---------------------
diff --git a/tests/test_filters.py b/tests/test_filters.py
index dc84dcbd..355f02ce 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -5,13 +5,15 @@ from django.db import models
from django.conf.urls import patterns, url
from django.core.urlresolvers import reverse
from django.test import TestCase
+from django.test.utils import override_settings
from django.utils import unittest
from django.utils.dateparse import parse_date
+from django.utils.six.moves import reload_module
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
from .models import BaseFilterableItem, FilterableItem, BasicModel
-from .utils import temporary_setting
+
factory = APIRequestFactory()
@@ -404,7 +406,9 @@ class SearchFilterTests(TestCase):
)
def test_search_with_nonstandard_search_param(self):
- with temporary_setting('SEARCH_PARAM', 'query', module=filters):
+ with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
+ reload_module(filters)
+
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
@@ -422,6 +426,8 @@ class SearchFilterTests(TestCase):
]
)
+ reload_module(filters)
+
class OrderingFilterModel(models.Model):
title = models.CharField(max_length=20)
@@ -467,6 +473,7 @@ class DjangoFilterOrderingTests(TestCase):
for d in data:
DjangoFilterOrderingModel.objects.create(**d)
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
def test_default_ordering(self):
class DjangoFilterOrderingView(generics.ListAPIView):
serializer_class = DjangoFilterOrderingSerializer
@@ -641,7 +648,9 @@ class OrderingFilterTests(TestCase):
)
def test_ordering_with_nonstandard_ordering_param(self):
- with temporary_setting('ORDERING_PARAM', 'order', filters):
+ with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
+ reload_module(filters)
+
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
@@ -661,6 +670,8 @@ class OrderingFilterTests(TestCase):
]
)
+ reload_module(filters)
+
class SensitiveOrderingFilterModel(models.Model):
username = models.CharField(max_length=20)
diff --git a/tests/test_generics.py b/tests/test_generics.py
index 94023c30..88e792ce 100644
--- a/tests/test_generics.py
+++ b/tests/test_generics.py
@@ -117,7 +117,7 @@ class TestRootView(TestCase):
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
+ self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'})
def test_delete_root_view(self):
"""
@@ -127,7 +127,7 @@ class TestRootView(TestCase):
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
+ self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'})
def test_post_cannot_set_id(self):
"""
@@ -181,7 +181,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
+ self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'})
def test_put_instance_view(self):
"""
@@ -483,7 +483,7 @@ class TestFilterBackendAppliedToViews(TestCase):
request = factory.get('/1')
response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.data, {'detail': 'Not found'})
+ self.assertEqual(response.data, {'detail': 'Not found.'})
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
"""
diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py
index 2edc6b4b..a33b832f 100644
--- a/tests/test_htmlrenderer.py
+++ b/tests/test_htmlrenderer.py
@@ -56,7 +56,13 @@ class TemplateHTMLRendererTests(TestCase):
return Template("example: {{ object }}")
raise TemplateDoesNotExist(template_name)
+ def select_template(template_name_list, dirs=None, using=None):
+ if template_name_list == ['example.html']:
+ return Template("example: {{ object }}")
+ raise TemplateDoesNotExist(template_name_list[0])
+
django.template.loader.get_template = get_template
+ django.template.loader.select_template = select_template
def tearDown(self):
"""
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 5ff59c72..5031c0f3 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -1,9 +1,8 @@
from __future__ import unicode_literals
-
-from rest_framework import exceptions, serializers, views
+from rest_framework import exceptions, serializers, status, views, versioning
from rest_framework.request import Request
+from rest_framework.renderers import BrowsableAPIRenderer
from rest_framework.test import APIRequestFactory
-import pytest
request = Request(APIRequestFactory().options('/'))
@@ -17,7 +16,8 @@ class TestMetadata:
"""Example view."""
pass
- response = ExampleView().options(request=request)
+ view = ExampleView.as_view()
+ response = view(request=request)
expected = {
'name': 'Example',
'description': 'Example view.',
@@ -31,7 +31,7 @@ class TestMetadata:
'multipart/form-data'
]
}
- assert response.status_code == 200
+ assert response.status_code == status.HTTP_200_OK
assert response.data == expected
def test_none_metadata(self):
@@ -42,8 +42,10 @@ class TestMetadata:
class ExampleView(views.APIView):
metadata_class = None
- with pytest.raises(exceptions.MethodNotAllowed):
- ExampleView().options(request=request)
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
+ assert response.data == {'detail': 'Method "OPTIONS" not allowed.'}
def test_actions(self):
"""
@@ -63,7 +65,8 @@ class TestMetadata:
def get_serializer(self):
return ExampleSerializer()
- response = ExampleView().options(request=request)
+ view = ExampleView.as_view()
+ response = view(request=request)
expected = {
'name': 'Example',
'description': 'Example view.',
@@ -104,7 +107,7 @@ class TestMetadata:
}
}
}
- assert response.status_code == 200
+ assert response.status_code == status.HTTP_200_OK
assert response.data == expected
def test_global_permissions(self):
@@ -132,8 +135,9 @@ class TestMetadata:
if request.method == 'POST':
raise exceptions.PermissionDenied()
- response = ExampleView().options(request=request)
- assert response.status_code == 200
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_200_OK
assert list(response.data['actions'].keys()) == ['PUT']
def test_object_permissions(self):
@@ -161,6 +165,36 @@ class TestMetadata:
if self.request.method == 'PUT':
raise exceptions.PermissionDenied()
- response = ExampleView().options(request=request)
- assert response.status_code == 200
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_200_OK
assert list(response.data['actions'].keys()) == ['POST']
+
+ def test_bug_2455_clone_request(self):
+ class ExampleView(views.APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+
+ def post(self, request):
+ pass
+
+ def get_serializer(self):
+ assert hasattr(self.request, 'version')
+ return serializers.Serializer()
+
+ view = ExampleView.as_view()
+ view(request=request)
+
+ def test_bug_2477_clone_request(self):
+ class ExampleView(views.APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+
+ def post(self, request):
+ pass
+
+ def get_serializer(self):
+ assert hasattr(self.request, 'versioning_scheme')
+ return serializers.Serializer()
+
+ scheme = versioning.QueryParameterVersioning
+ view = ExampleView.as_view(versioning_class=scheme)
+ view(request=request)
diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py
index 5c56c8db..bce2008a 100644
--- a/tests/test_model_serializer.py
+++ b/tests/test_model_serializer.py
@@ -5,11 +5,14 @@ shortcuts for automatically creating serializers based on a given model class.
These tests deal with ensuring that we correctly map the model fields onto
an appropriate set of serializer fields for each case.
"""
+from __future__ import unicode_literals
from django.core.exceptions import ImproperlyConfigured
from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator
from django.db import models
from django.test import TestCase
+from django.utils import six
from rest_framework import serializers
+from rest_framework.compat import unicode_repr
def dedent(blocktext):
@@ -119,12 +122,12 @@ class TestRegularFieldMappings(TestCase):
positive_small_integer_field = IntegerField()
slug_field = SlugField(max_length=100)
small_integer_field = IntegerField()
- text_field = CharField(style={'type': 'textarea'})
+ text_field = CharField(style={'base_template': 'textarea.html'})
time_field = TimeField()
url_field = URLField(max_length=100)
custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_field_options(self):
class TestSerializer(serializers.ModelSerializer):
@@ -142,7 +145,14 @@ class TestRegularFieldMappings(TestCase):
descriptive_field = IntegerField(help_text='Some help text', label='A label')
choices_field = ChoiceField(choices=[('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')])
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ if six.PY2:
+ # This particular case is too awkward to resolve fully across
+ # both py2 and py3.
+ expected = expected.replace(
+ "('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')",
+ "(u'red', u'Red'), (u'blue', u'Blue'), (u'green', u'Green')"
+ )
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_method_field(self):
"""
@@ -206,7 +216,7 @@ class TestRegularFieldMappings(TestCase):
with self.assertRaises(ImproperlyConfigured) as excinfo:
TestSerializer().fields
- expected = 'Field name `invalid` is not valid for model `ModelBase`.'
+ expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.'
assert str(excinfo.exception) == expected
def test_missing_field(self):
@@ -229,6 +239,26 @@ class TestRegularFieldMappings(TestCase):
)
assert str(excinfo.exception) == expected
+ def test_missing_superclass_field(self):
+ """
+ Fields that have been declared on a parent of the serializer class may
+ be excluded from the `Meta.fields` option.
+ """
+ class TestSerializer(serializers.ModelSerializer):
+ missing = serializers.ReadOnlyField()
+
+ class Meta:
+ model = RegularFieldsModel
+
+ class ChildSerializer(TestSerializer):
+ missing = serializers.ReadOnlyField()
+
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field',)
+
+ ChildSerializer().fields
+
# Tests for relational field mappings.
# ------------------------------------
@@ -276,7 +306,7 @@ class TestRelationalFieldMappings(TestCase):
many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all())
through = PrimaryKeyRelatedField(many=True, read_only=True)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_nested_relations(self):
class TestSerializer(serializers.ModelSerializer):
@@ -300,7 +330,7 @@ class TestRelationalFieldMappings(TestCase):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
@@ -315,7 +345,7 @@ class TestRelationalFieldMappings(TestCase):
many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_nested_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
@@ -339,7 +369,7 @@ class TestRelationalFieldMappings(TestCase):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_pk_reverse_foreign_key(self):
class TestSerializer(serializers.ModelSerializer):
@@ -353,7 +383,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100)
reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_pk_reverse_one_to_one(self):
class TestSerializer(serializers.ModelSerializer):
@@ -367,7 +397,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100)
reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all())
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_pk_reverse_many_to_many(self):
class TestSerializer(serializers.ModelSerializer):
@@ -381,7 +411,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100)
reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_pk_reverse_through(self):
class TestSerializer(serializers.ModelSerializer):
@@ -395,7 +425,7 @@ class TestRelationalFieldMappings(TestCase):
name = CharField(max_length=100)
reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)
""")
- self.assertEqual(repr(TestSerializer()), expected)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
class TestIntegration(TestCase):
diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py
index e1b40cc7..15627e1d 100644
--- a/tests/test_multitable_inheritance.py
+++ b/tests/test_multitable_inheritance.py
@@ -48,8 +48,8 @@ class InheritedModelSerializationTests(TestCase):
Assert that a model with a onetoone field that is the primary key is
not treated like a derived model
"""
- parent = ParentModel(name1='parent name')
- associate = AssociatedModel(name='hello', ref=parent)
+ parent = ParentModel.objects.create(name1='parent name')
+ associate = AssociatedModel.objects.create(name='hello', ref=parent)
serializer = AssociatedModelSerializer(associate)
self.assertEqual(set(serializer.data.keys()),
set(['name', 'ref']))
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
index 1fd9cf9c..13bfb627 100644
--- a/tests/test_pagination.py
+++ b/tests/test_pagination.py
@@ -1,553 +1,671 @@
+# coding: utf-8
from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.core.paginator import Paginator
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import generics, serializers, status, pagination, filters
-from rest_framework.compat import django_filters
+from rest_framework import exceptions, generics, pagination, serializers, status, filters
+from rest_framework.request import Request
+from rest_framework.pagination import PageLink, PAGE_BREAK
from rest_framework.test import APIRequestFactory
-from .models import BasicModel, FilterableItem
+import pytest
factory = APIRequestFactory()
-# Helper function to split arguments out of an url
-def split_arguments_from_url(url):
- if '?' not in url:
- return url
+class TestPaginationIntegration:
+ """
+ Integration tests.
+ """
- path, args = url.split('?')
- args = dict(r.split('=') for r in args.split('&'))
- return path, args
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
+ class EvenItemsOnly(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return [item for item in queryset if item % 2 == 0]
+
+ class BasicPagination(pagination.PageNumberPagination):
+ paginate_by = 5
+ paginate_by_param = 'page_size'
+ max_paginate_by = 20
+
+ self.view = generics.ListAPIView.as_view(
+ serializer_class=PassThroughSerializer,
+ queryset=range(1, 101),
+ filter_backends=[EvenItemsOnly],
+ pagination_class=BasicPagination
+ )
-class BasicSerializer(serializers.ModelSerializer):
- class Meta:
- model = BasicModel
+ def test_filtered_items_are_paginated(self):
+ request = factory.get('/', {'page': 2})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [12, 14, 16, 18, 20],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page=3',
+ 'count': 50
+ }
+ def test_setting_page_size(self):
+ """
+ When 'paginate_by_param' is set, the client may choose a page size.
+ """
+ request = factory.get('/', {'page_size': 10})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=10',
+ 'count': 50
+ }
-class FilterableItemSerializer(serializers.ModelSerializer):
- class Meta:
- model = FilterableItem
+ def test_setting_page_size_over_maximum(self):
+ """
+ When page_size parameter exceeds maxiumum allowable,
+ then it should be capped to the maxiumum.
+ """
+ request = factory.get('/', {'page_size': 1000})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [
+ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20,
+ 22, 24, 26, 28, 30, 32, 34, 36, 38, 40
+ ],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=1000',
+ 'count': 50
+ }
+ def test_setting_page_size_to_zero(self):
+ """
+ When page_size parameter is invalid it should return to the default.
+ """
+ request = factory.get('/', {'page_size': 0})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [2, 4, 6, 8, 10],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=0',
+ 'count': 50
+ }
-class RootView(generics.ListCreateAPIView):
- """
- Example description for OPTIONS.
- """
- queryset = BasicModel.objects.all()
- serializer_class = BasicSerializer
- paginate_by = 10
+ def test_additional_query_params_are_preserved(self):
+ request = factory.get('/', {'page': 2, 'filter': 'even'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [12, 14, 16, 18, 20],
+ 'previous': 'http://testserver/?filter=even',
+ 'next': 'http://testserver/?filter=even&page=3',
+ 'count': 50
+ }
+ def test_404_not_found_for_zero_page(self):
+ request = factory.get('/', {'page': '0'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.data == {
+ 'detail': 'Invalid page "0": That page number is less than 1.'
+ }
-class DefaultPageSizeKwargView(generics.ListAPIView):
- """
- View for testing default paginate_by_param usage
- """
- queryset = BasicModel.objects.all()
- serializer_class = BasicSerializer
+ def test_404_not_found_for_invalid_page(self):
+ request = factory.get('/', {'page': 'invalid'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.data == {
+ 'detail': 'Invalid page "invalid": That page number is not an integer.'
+ }
-class PaginateByParamView(generics.ListAPIView):
+class TestPaginationDisabledIntegration:
"""
- View for testing custom paginate_by_param usage
+ Integration tests for disabled pagination.
"""
- queryset = BasicModel.objects.all()
- serializer_class = BasicSerializer
- paginate_by_param = 'page_size'
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
-class MaxPaginateByView(generics.ListAPIView):
- """
- View for testing custom max_paginate_by usage
- """
- queryset = BasicModel.objects.all()
- serializer_class = BasicSerializer
- paginate_by = 3
- max_paginate_by = 5
- paginate_by_param = 'page_size'
+ self.view = generics.ListAPIView.as_view(
+ serializer_class=PassThroughSerializer,
+ queryset=range(1, 101),
+ pagination_class=None
+ )
+
+ def test_unpaginated_list(self):
+ request = factory.get('/', {'page': 2})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == list(range(1, 101))
-class IntegrationTestPagination(TestCase):
+class TestDeprecatedStylePagination:
"""
- Integration tests for paginated list views.
+ Integration tests for deprecated style of setting pagination
+ attributes on the view.
"""
- def setUp(self):
- """
- Create 26 BasicModel instances.
- """
- for char in 'abcdefghijklmnopqrstuvwxyz':
- BasicModel(text=char * 3).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = RootView.as_view()
-
- def test_get_paginated_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- request = factory.get('/')
- # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>`
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[10:20])
- self.assertNotEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[20:])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
-
-class IntegrationTestPaginationAndFiltering(TestCase):
-
- def setUp(self):
- """
- Create 50 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(26):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()}
- for obj in self.objects.all()
- ]
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_django_filter_paginated_filtered_root_view(self):
- """
- GET requests to paginated filtered ListCreateAPIView should return
- paginated results. The next and previous links should preserve the
- filtered parameters.
- """
- class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterFieldsRootView(generics.ListCreateAPIView):
- queryset = FilterableItem.objects.all()
- serializer_class = FilterableItemSerializer
- paginate_by = 10
- filter_class = DecimalFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- view = FilterFieldsRootView.as_view()
-
- EXPECTED_NUM_QUERIES = 2
-
- request = factory.get('/', {'decimal': '15.20'})
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[10:15])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['previous']))
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- def test_get_basic_paginated_filtered_root_view(self):
- """
- Same as `test_get_django_filter_paginated_filtered_root_view`,
- except using a custom filter backend instead of the django-filter
- backend,
- """
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
- class DecimalFilterBackend(filters.BaseFilterBackend):
- def filter_queryset(self, request, queryset, view):
- return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
-
- class BasicFilterFieldsRootView(generics.ListCreateAPIView):
- queryset = FilterableItem.objects.all()
- serializer_class = FilterableItemSerializer
- paginate_by = 10
- filter_backends = (DecimalFilterBackend,)
-
- view = BasicFilterFieldsRootView.as_view()
-
- request = factory.get('/', {'decimal': '15.20'})
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['next']))
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[10:15])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(*split_arguments_from_url(response.data['previous']))
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
-
-class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
- class Meta:
- object_serializer_class = serializers.Serializer
-
-
-class UnitTestPagination(TestCase):
- """
- Unit tests for pagination of primitive objects.
- """
+ class ExampleView(generics.ListAPIView):
+ serializer_class = PassThroughSerializer
+ queryset = range(1, 101)
+ pagination_class = pagination.PageNumberPagination
+ paginate_by = 20
+ page_query_param = 'page_number'
- def setUp(self):
- self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
- paginator = Paginator(self.objects, 10)
- self.first_page = paginator.page(1)
- self.last_page = paginator.page(3)
-
- def test_native_pagination(self):
- serializer = pagination.PaginationSerializer(self.first_page)
- self.assertEqual(serializer.data['count'], 26)
- self.assertEqual(serializer.data['next'], '?page=2')
- self.assertEqual(serializer.data['previous'], None)
- self.assertEqual(serializer.data['results'], self.objects[:10])
-
- serializer = pagination.PaginationSerializer(self.last_page)
- self.assertEqual(serializer.data['count'], 26)
- self.assertEqual(serializer.data['next'], None)
- self.assertEqual(serializer.data['previous'], '?page=2')
- self.assertEqual(serializer.data['results'], self.objects[20:])
-
- def test_context_available_in_result(self):
- """
- Ensure context gets passed through to the object serializer.
- """
- serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
- serializer.data
- results = serializer.fields[serializer.results_field]
- self.assertEqual(serializer.context, results.context)
+ self.view = ExampleView.as_view()
+
+ def test_paginate_by_attribute_on_view(self):
+ request = factory.get('/?page_number=2')
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [
+ 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40
+ ],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page_number=3',
+ 'count': 100
+ }
-class TestUnpaginated(TestCase):
+class TestPageNumberPagination:
"""
- Tests for list views without pagination.
+ Unit tests for `pagination.PageNumberPagination`.
"""
- def setUp(self):
- """
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = DefaultPageSizeKwargView.as_view()
-
- def test_unpaginated(self):
- """
- Tests the default page size for this view.
- no page size --> no limit --> no meta data
- """
- request = factory.get('/')
- response = self.view(request)
- self.assertEqual(response.data, self.data)
+ def setup(self):
+ class ExamplePagination(pagination.PageNumberPagination):
+ paginate_by = 5
+ self.pagination = ExamplePagination()
+ self.queryset = range(1, 101)
+
+ def paginate_queryset(self, request):
+ return list(self.pagination.paginate_queryset(self.queryset, request))
+
+ def get_paginated_content(self, queryset):
+ response = self.pagination.get_paginated_response(queryset)
+ return response.data
+
+ def get_html_context(self):
+ return self.pagination.get_html_context()
+
+ def test_no_page_number(self):
+ request = Request(factory.get('/'))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [1, 2, 3, 4, 5]
+ assert content == {
+ 'results': [1, 2, 3, 4, 5],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': None,
+ 'next_url': 'http://testserver/?page=2',
+ 'page_links': [
+ PageLink('http://testserver/', 1, True, False),
+ PageLink('http://testserver/?page=2', 2, False, False),
+ PageLink('http://testserver/?page=3', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=20', 20, False, False),
+ ]
+ }
+ assert self.pagination.display_page_controls
+ assert isinstance(self.pagination.to_html(), type(''))
+
+ def test_second_page(self):
+ request = Request(factory.get('/', {'page': 2}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [6, 7, 8, 9, 10]
+ assert content == {
+ 'results': [6, 7, 8, 9, 10],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page=3',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/',
+ 'next_url': 'http://testserver/?page=3',
+ 'page_links': [
+ PageLink('http://testserver/', 1, False, False),
+ PageLink('http://testserver/?page=2', 2, True, False),
+ PageLink('http://testserver/?page=3', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=20', 20, False, False),
+ ]
+ }
+
+ def test_last_page(self):
+ request = Request(factory.get('/', {'page': 'last'}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [96, 97, 98, 99, 100]
+ assert content == {
+ 'results': [96, 97, 98, 99, 100],
+ 'previous': 'http://testserver/?page=19',
+ 'next': None,
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?page=19',
+ 'next_url': None,
+ 'page_links': [
+ PageLink('http://testserver/', 1, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=18', 18, False, False),
+ PageLink('http://testserver/?page=19', 19, False, False),
+ PageLink('http://testserver/?page=20', 20, True, False),
+ ]
+ }
+
+ def test_invalid_page(self):
+ request = Request(factory.get('/', {'page': 'invalid'}))
+ with pytest.raises(exceptions.NotFound):
+ self.paginate_queryset(request)
-class TestCustomPaginateByParam(TestCase):
+class TestLimitOffset:
"""
- Tests for list views with default page size kwarg
+ Unit tests for `pagination.LimitOffsetPagination`.
"""
- def setUp(self):
+ def setup(self):
+ class ExamplePagination(pagination.LimitOffsetPagination):
+ default_limit = 10
+ self.pagination = ExamplePagination()
+ self.queryset = range(1, 101)
+
+ def paginate_queryset(self, request):
+ return list(self.pagination.paginate_queryset(self.queryset, request))
+
+ def get_paginated_content(self, queryset):
+ response = self.pagination.get_paginated_response(queryset)
+ return response.data
+
+ def get_html_context(self):
+ return self.pagination.get_html_context()
+
+ def test_no_offset(self):
+ request = Request(factory.get('/', {'limit': 5}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [1, 2, 3, 4, 5]
+ assert content == {
+ 'results': [1, 2, 3, 4, 5],
+ 'previous': None,
+ 'next': 'http://testserver/?limit=5&offset=5',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': None,
+ 'next_url': 'http://testserver/?limit=5&offset=5',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, True, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+ assert self.pagination.display_page_controls
+ assert isinstance(self.pagination.to_html(), type(''))
+
+ def test_single_offset(self):
"""
- Create 13 BasicModel instances.
+ When the offset is not a multiple of the limit we get some edge cases:
+ * The first page should still be offset zero.
+ * We may end up displaying an extra page in the pagination control.
"""
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = PaginateByParamView.as_view()
-
- def test_default_page_size(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 1}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [2, 3, 4, 5, 6]
+ assert content == {
+ 'results': [2, 3, 4, 5, 6],
+ 'previous': 'http://testserver/?limit=5',
+ 'next': 'http://testserver/?limit=5&offset=6',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5',
+ 'next_url': 'http://testserver/?limit=5&offset=6',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=1', 2, True, False),
+ PageLink('http://testserver/?limit=5&offset=6', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=96', 21, False, False),
+ ]
+ }
+
+ def test_first_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 5}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [6, 7, 8, 9, 10]
+ assert content == {
+ 'results': [6, 7, 8, 9, 10],
+ 'previous': 'http://testserver/?limit=5',
+ 'next': 'http://testserver/?limit=5&offset=10',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5',
+ 'next_url': 'http://testserver/?limit=5&offset=10',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, True, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+
+ def test_middle_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 10}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [11, 12, 13, 14, 15]
+ assert content == {
+ 'results': [11, 12, 13, 14, 15],
+ 'previous': 'http://testserver/?limit=5&offset=5',
+ 'next': 'http://testserver/?limit=5&offset=15',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5&offset=5',
+ 'next_url': 'http://testserver/?limit=5&offset=15',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, True, False),
+ PageLink('http://testserver/?limit=5&offset=15', 4, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+
+ def test_ending_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 95}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [96, 97, 98, 99, 100]
+ assert content == {
+ 'results': [96, 97, 98, 99, 100],
+ 'previous': 'http://testserver/?limit=5&offset=90',
+ 'next': None,
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5&offset=90',
+ 'next_url': None,
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=85', 18, False, False),
+ PageLink('http://testserver/?limit=5&offset=90', 19, False, False),
+ PageLink('http://testserver/?limit=5&offset=95', 20, True, False),
+ ]
+ }
+
+ def test_invalid_offset(self):
"""
- Tests the default page size for this view.
- no page size --> no limit --> no meta data
+ An invalid offset query param should be treated as 0.
"""
- request = factory.get('/')
- response = self.view(request).render()
- self.assertEqual(response.data, self.data)
+ request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'}))
+ queryset = self.paginate_queryset(request)
+ assert queryset == [1, 2, 3, 4, 5]
- def test_paginate_by_param(self):
+ def test_invalid_limit(self):
"""
- If paginate_by_param is set, the new kwarg should limit per view requests.
+ An invalid limit query param should be ignored in favor of the default.
"""
- request = factory.get('/', {'page_size': 5})
- response = self.view(request).render()
- self.assertEqual(response.data['count'], 13)
- self.assertEqual(response.data['results'], self.data[:5])
+ request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0}))
+ queryset = self.paginate_queryset(request)
+ assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-class TestMaxPaginateByParam(TestCase):
+class TestCursorPagination:
"""
- Tests for list views with max_paginate_by kwarg
+ Unit tests for `pagination.CursorPagination`.
"""
- def setUp(self):
+ def setup(self):
+ class MockObject(object):
+ def __init__(self, idx):
+ self.created = idx
+
+ class MockQuerySet(object):
+ def __init__(self, items):
+ self.items = items
+
+ def filter(self, created__gt=None, created__lt=None):
+ if created__gt is not None:
+ return MockQuerySet([
+ item for item in self.items
+ if item.created > int(created__gt)
+ ])
+
+ assert created__lt is not None
+ return MockQuerySet([
+ item for item in self.items
+ if item.created < int(created__lt)
+ ])
+
+ def order_by(self, *ordering):
+ if ordering[0].startswith('-'):
+ return MockQuerySet(list(reversed(self.items)))
+ return self
+
+ def __getitem__(self, sliced):
+ return self.items[sliced]
+
+ class ExamplePagination(pagination.CursorPagination):
+ page_size = 5
+ ordering = 'created'
+
+ self.pagination = ExamplePagination()
+ self.queryset = MockQuerySet([
+ MockObject(idx) for idx in [
+ 1, 1, 1, 1, 1,
+ 1, 2, 3, 4, 4,
+ 4, 4, 5, 6, 7,
+ 7, 7, 7, 7, 7,
+ 7, 7, 7, 8, 9,
+ 9, 9, 9, 9, 9
+ ]
+ ])
+
+ def get_pages(self, url):
"""
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = MaxPaginateByView.as_view()
-
- def test_max_paginate_by(self):
- """
- If max_paginate_by is set, it should limit page size for the view.
- """
- request = factory.get('/', data={'page_size': 10})
- response = self.view(request).render()
- self.assertEqual(response.data['count'], 13)
- self.assertEqual(response.data['results'], self.data[:5])
+ Given a URL return a tuple of:
- def test_max_paginate_by_without_page_size_param(self):
+ (previous page, current page, next page, previous url, next url)
"""
- If max_paginate_by is set, but client does not specifiy page_size,
- standard `paginate_by` behavior should be used.
- """
- request = factory.get('/')
- response = self.view(request).render()
- self.assertEqual(response.data['results'], self.data[:3])
-
-
-# Tests for context in pagination serializers
+ request = Request(factory.get(url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ current = [item.created for item in queryset]
-class CustomField(serializers.ReadOnlyField):
- def to_native(self, value):
- if 'view' not in self.context:
- raise RuntimeError("context isn't getting passed into custom field")
- return "value"
+ next_url = self.pagination.get_next_link()
+ previous_url = self.pagination.get_previous_link()
+ if next_url is not None:
+ request = Request(factory.get(next_url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ next = [item.created for item in queryset]
+ else:
+ next = None
-class BasicModelSerializer(serializers.Serializer):
- text = CustomField()
-
- def to_native(self, value):
- if 'view' not in self.context:
- raise RuntimeError("context isn't getting passed into serializer")
- return super(BasicSerializer, self).to_native(value)
+ if previous_url is not None:
+ request = Request(factory.get(previous_url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ previous = [item.created for item in queryset]
+ else:
+ previous = None
+ return (previous, current, next, previous_url, next_url)
-class TestContextPassedToCustomField(TestCase):
- def setUp(self):
- BasicModel.objects.create(text='ala ma kota')
+ def test_invalid_cursor(self):
+ request = Request(factory.get('/', {'cursor': '123'}))
+ with pytest.raises(exceptions.NotFound):
+ self.pagination.paginate_queryset(self.queryset, request)
- def test_with_pagination(self):
- class ListView(generics.ListCreateAPIView):
- queryset = BasicModel.objects.all()
- serializer_class = BasicModelSerializer
- paginate_by = 1
+ def test_use_with_ordering_filter(self):
+ class MockView:
+ filter_backends = (filters.OrderingFilter,)
+ ordering_fields = ['username', 'created']
+ ordering = 'created'
- self.view = ListView.as_view()
- request = factory.get('/')
- response = self.view(request).render()
+ request = Request(factory.get('/', {'ordering': 'username'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('username',)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
+ request = Request(factory.get('/', {'ordering': '-username'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('-username',)
+ request = Request(factory.get('/', {'ordering': 'invalid'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('created',)
-# Tests for custom pagination serializers
+ def test_cursor_pagination(self):
+ (previous, current, next, previous_url, next_url) = self.get_pages('/')
-class LinksSerializer(serializers.Serializer):
- next = pagination.NextPageField(source='*')
- prev = pagination.PreviousPageField(source='*')
+ assert previous is None
+ assert current == [1, 1, 1, 1, 1]
+ assert next == [1, 2, 3, 4, 4]
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
-class CustomPaginationSerializer(pagination.BasePaginationSerializer):
- links = LinksSerializer(source='*') # Takes the page object as the source
- total_results = serializers.ReadOnlyField(source='paginator.count')
+ assert previous == [1, 1, 1, 1, 1]
+ assert current == [1, 2, 3, 4, 4]
+ assert next == [4, 4, 5, 6, 7]
- results_field = 'objects'
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+ assert previous == [1, 2, 3, 4, 4]
+ assert current == [4, 4, 5, 6, 7]
+ assert next == [7, 7, 7, 7, 7]
-class CustomFooSerializer(serializers.Serializer):
- foo = serializers.CharField()
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+ assert previous == [4, 4, 4, 5, 6] # Paging artifact
+ assert current == [7, 7, 7, 7, 7]
+ assert next == [7, 7, 7, 8, 9]
-class CustomFooPaginationSerializer(pagination.PaginationSerializer):
- class Meta:
- object_serializer_class = CustomFooSerializer
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+ assert previous == [7, 7, 7, 7, 7]
+ assert current == [7, 7, 7, 8, 9]
+ assert next == [9, 9, 9, 9, 9]
-class TestCustomPaginationSerializer(TestCase):
- def setUp(self):
- objects = ['john', 'paul', 'george', 'ringo']
- paginator = Paginator(objects, 2)
- self.page = paginator.page(1)
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
- def test_custom_pagination_serializer(self):
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=self.page,
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': 'http://testserver/foobar?page=2',
- 'prev': None
- },
- 'total_results': 4,
- 'objects': ['john', 'paul']
- }
- self.assertEqual(serializer.data, expected)
+ assert previous == [7, 7, 7, 8, 9]
+ assert current == [9, 9, 9, 9, 9]
+ assert next is None
- def test_custom_pagination_serializer_with_custom_object_serializer(self):
- objects = [
- {'foo': 'bar'},
- {'foo': 'spam'}
- ]
- paginator = Paginator(objects, 1)
- page = paginator.page(1)
- serializer = CustomFooPaginationSerializer(page)
- serializer.data
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
+ assert previous == [7, 7, 7, 7, 7]
+ assert current == [7, 7, 7, 8, 9]
+ assert next == [9, 9, 9, 9, 9]
-class NonIntegerPage(object):
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
- def __init__(self, paginator, object_list, prev_token, token, next_token):
- self.paginator = paginator
- self.object_list = object_list
- self.prev_token = prev_token
- self.token = token
- self.next_token = next_token
+ assert previous == [4, 4, 5, 6, 7]
+ assert current == [7, 7, 7, 7, 7]
+ assert next == [8, 9, 9, 9, 9] # Paging artifact
- def has_next(self):
- return not not self.next_token
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
- def next_page_number(self):
- return self.next_token
+ assert previous == [1, 2, 3, 4, 4]
+ assert current == [4, 4, 5, 6, 7]
+ assert next == [7, 7, 7, 7, 7]
- def has_previous(self):
- return not not self.prev_token
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
- def previous_page_number(self):
- return self.prev_token
+ assert previous == [1, 1, 1, 1, 1]
+ assert current == [1, 2, 3, 4, 4]
+ assert next == [4, 4, 5, 6, 7]
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
-class NonIntegerPaginator(object):
+ assert previous is None
+ assert current == [1, 1, 1, 1, 1]
+ assert next == [1, 2, 3, 4, 4]
- def __init__(self, object_list, per_page):
- self.object_list = object_list
- self.per_page = per_page
+ assert isinstance(self.pagination.to_html(), type(''))
- def count(self):
- # pretend like we don't know how many pages we have
- return None
- def page(self, token=None):
- if token:
- try:
- first = self.object_list.index(token)
- except ValueError:
- first = 0
- else:
- first = 0
- n = len(self.object_list)
- last = min(first + self.per_page, n)
- prev_token = self.object_list[last - (2 * self.per_page)] if first else None
- next_token = self.object_list[last] if last < n else None
- return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
-
-
-class TestNonIntegerPagination(TestCase):
- def test_custom_pagination_serializer(self):
- objects = ['john', 'paul', 'george', 'ringo']
- paginator = NonIntegerPaginator(objects, 2)
-
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=paginator.page(),
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
- 'prev': None
- },
- 'total_results': None,
- 'objects': objects[:2]
- }
- self.assertEqual(serializer.data, expected)
+def test_get_displayed_page_numbers():
+ """
+ Test our contextual page display function.
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=paginator.page('george'),
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': None,
- 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
- },
- 'total_results': None,
- 'objects': objects[2:]
- }
- self.assertEqual(serializer.data, expected)
+ This determines which pages to display in a pagination control,
+ given the current page and the last page.
+ """
+ displayed_page_numbers = pagination._get_displayed_page_numbers
+
+ # At five pages or less, all pages are displayed, always.
+ assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5]
+
+ # Between six and either pages we may have a single page break.
+ assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6]
+ assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6]
+ assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6]
+ assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6]
+ assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6]
+ assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6]
+
+ assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7]
+ assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7]
+ assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7]
+ assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7]
+ assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7]
+ assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7]
+ assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7]
+
+ assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8]
+ assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8]
+ assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8]
+ assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8]
+ assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8]
+ assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8]
+ assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8]
+ assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8]
+
+ # At nine or more pages we may have two page breaks, one on each side.
+ assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9]
+ assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9]
+ assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9]
+ assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9]
+ assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9]
+ assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9]
+ assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9]
+ assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9]
+ assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9]
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
index 54455cf6..8816065a 100644
--- a/tests/test_parsers.py
+++ b/tests/test_parsers.py
@@ -101,7 +101,9 @@ class TestFileUploadParser(TestCase):
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8--ÀĥƦ.txt')
filename = parser.get_filename(self.stream, None, self.parser_context)
- self.assertEqual(filename, 'fallback.txt')
+ # Malformed. Either None or 'fallback.txt' will be acceptable.
+ # See also https://code.djangoproject.com/ticket/24209
+ self.assertIn(filename, ('fallback.txt', None))
def __replace_content_disposition(self, disposition):
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
diff --git a/tests/test_relations.py b/tests/test_relations.py
index 62353dc2..fbe176e2 100644
--- a/tests/test_relations.py
+++ b/tests/test_relations.py
@@ -1,6 +1,8 @@
from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset
from django.core.exceptions import ImproperlyConfigured
+from django.utils.datastructures import MultiValueDict
from rest_framework import serializers
+from rest_framework.fields import empty
from rest_framework.test import APISimpleTestCase
import pytest
@@ -33,7 +35,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(4)
msg = excinfo.value.detail[0]
- assert msg == "Invalid pk '4' - object does not exist."
+ assert msg == 'Invalid pk "4" - object does not exist.'
def test_pk_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
@@ -134,3 +136,34 @@ class TestSlugRelatedField(APISimpleTestCase):
def test_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == self.instance.name
+
+
+class TestManyRelatedField(APISimpleTestCase):
+ def setUp(self):
+ self.instance = MockObject(pk=1, name='foo')
+ self.field = serializers.StringRelatedField(many=True)
+ self.field.field_name = 'foo'
+
+ def test_get_value_regular_dictionary_full(self):
+ assert 'bar' == self.field.get_value({'foo': 'bar'})
+ assert empty == self.field.get_value({'baz': 'bar'})
+
+ def test_get_value_regular_dictionary_partial(self):
+ setattr(self.field.root, 'partial', True)
+ assert 'bar' == self.field.get_value({'foo': 'bar'})
+ assert empty == self.field.get_value({'baz': 'bar'})
+
+ def test_get_value_multi_dictionary_full(self):
+ mvd = MultiValueDict({'foo': ['bar1', 'bar2']})
+ assert ['bar1', 'bar2'] == self.field.get_value(mvd)
+
+ mvd = MultiValueDict({'baz': ['bar1', 'bar2']})
+ assert [] == self.field.get_value(mvd)
+
+ def test_get_value_multi_dictionary_partial(self):
+ setattr(self.field.root, 'partial', True)
+ mvd = MultiValueDict({'foo': ['bar1', 'bar2']})
+ assert ['bar1', 'bar2'] == self.field.get_value(mvd)
+
+ mvd = MultiValueDict({'baz': ['bar1', 'bar2']})
+ assert empty == self.field.get_value(mvd)
diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py
index f1b882ed..aede61d2 100644
--- a/tests/test_relations_hyperlink.py
+++ b/tests/test_relations_hyperlink.py
@@ -1,5 +1,5 @@
from __future__ import unicode_literals
-from django.conf.urls import patterns, url
+from django.conf.urls import url
from django.test import TestCase
from rest_framework import serializers
from rest_framework.test import APIRequestFactory
@@ -14,8 +14,7 @@ request = factory.get('/') # Just to ensure we have a request in the serializer
dummy_view = lambda request, pk: None
-urlpatterns = patterns(
- '',
+urlpatterns = [
url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
@@ -24,7 +23,7 @@ urlpatterns = patterns(
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
-)
+]
# ManyToMany
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
index 7b78f7ba..f68405f0 100644
--- a/tests/test_renderers.py
+++ b/tests/test_renderers.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
-
from django.conf.urls import patterns, url, include
from django.core.cache import cache
from django.db import models
@@ -8,6 +7,7 @@ from django.test import TestCase
from django.utils import six
from django.utils.translation import ugettext_lazy as _
from rest_framework import status, permissions
+from rest_framework.compat import OrderedDict
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, BrowsableAPIRenderer
@@ -15,7 +15,6 @@ from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from collections import MutableMapping
import json
-import pickle
import re
@@ -408,84 +407,46 @@ class CacheRenderTest(TestCase):
urls = 'tests.test_renderers'
- cache_key = 'just_a_cache_key'
-
- @classmethod
- def _get_pickling_errors(cls, obj, seen=None):
- """ Return any errors that would be raised if `obj' is pickled
- Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
- """
- if seen is None:
- seen = []
- try:
- state = obj.__getstate__()
- except AttributeError:
- return
- if state is None:
- return
- if isinstance(state, tuple):
- if not isinstance(state[0], dict):
- state = state[1]
- else:
- state = state[0].update(state[1])
- result = {}
- for i in state:
- try:
- pickle.dumps(state[i], protocol=2)
- except pickle.PicklingError:
- if not state[i] in seen:
- seen.append(state[i])
- result[i] = cls._get_pickling_errors(state[i], seen)
- return result
-
- def http_resp(self, http_method, url):
- """
- Simple wrapper for Client http requests
- Removes the `client' and `request' attributes from as they are
- added by django.test.client.Client and not part of caching
- responses outside of tests.
- """
- method = getattr(self.client, http_method)
- resp = method(url)
- resp._closable_objects = []
- del resp.client, resp.request
- try:
- del resp.wsgi_request
- except AttributeError:
- pass
- return resp
-
- def test_obj_pickling(self):
- """
- Test that responses are properly pickled
- """
- resp = self.http_resp('get', '/cache')
-
- # Make sure that no pickling errors occurred
- self.assertEqual(self._get_pickling_errors(resp), {})
-
- # Unfortunately LocMem backend doesn't raise PickleErrors but returns
- # None instead.
- cache.set(self.cache_key, resp)
- self.assertTrue(cache.get(self.cache_key) is not None)
-
def test_head_caching(self):
"""
Test caching of HEAD requests
"""
- resp = self.http_resp('head', '/cache')
- cache.set(self.cache_key, resp)
-
- cached_resp = cache.get(self.cache_key)
- self.assertIsInstance(cached_resp, Response)
+ response = self.client.head('/cache')
+ cache.set('key', response)
+ cached_response = cache.get('key')
+ assert isinstance(cached_response, Response)
+ assert cached_response.content == response.content
+ assert cached_response.status_code == response.status_code
def test_get_caching(self):
"""
Test caching of GET requests
"""
- resp = self.http_resp('get', '/cache')
- cache.set(self.cache_key, resp)
+ response = self.client.get('/cache')
+ cache.set('key', response)
+ cached_response = cache.get('key')
+ assert isinstance(cached_response, Response)
+ assert cached_response.content == response.content
+ assert cached_response.status_code == response.status_code
+
+
+class TestJSONIndentationStyles:
+ def test_indented(self):
+ renderer = JSONRenderer()
+ data = OrderedDict([('a', 1), ('b', 2)])
+ assert renderer.render(data) == b'{"a":1,"b":2}'
- cached_resp = cache.get(self.cache_key)
- self.assertIsInstance(cached_resp, Response)
- self.assertEqual(cached_resp.content, resp.content)
+ def test_compact(self):
+ renderer = JSONRenderer()
+ data = OrderedDict([('a', 1), ('b', 2)])
+ context = {'indent': 4}
+ assert (
+ renderer.render(data, renderer_context=context) ==
+ b'{\n "a": 1,\n "b": 2\n}'
+ )
+
+ def test_long_form(self):
+ renderer = JSONRenderer()
+ renderer.compact = False
+ data = OrderedDict([('a', 1), ('b', 2)])
+ assert renderer.render(data) == b'{"a": 1, "b": 2}'
diff --git a/tests/test_routers.py b/tests/test_routers.py
index 34306146..948c69bb 100644
--- a/tests/test_routers.py
+++ b/tests/test_routers.py
@@ -1,17 +1,53 @@
from __future__ import unicode_literals
-from django.conf.urls import patterns, url, include
+from django.conf.urls import url, include
from django.db import models
from django.test import TestCase
from django.core.exceptions import ImproperlyConfigured
-from rest_framework import serializers, viewsets, mixins, permissions
+from rest_framework import serializers, viewsets, permissions
from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter
from rest_framework.test import APIRequestFactory
+from collections import namedtuple
factory = APIRequestFactory()
-urlpatterns = patterns('',)
+
+class RouterTestModel(models.Model):
+ uuid = models.CharField(max_length=20)
+ text = models.CharField(max_length=200)
+
+
+class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
+
+ class Meta:
+ model = RouterTestModel
+ fields = ('url', 'uuid', 'text')
+
+
+class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ serializer_class = NoteSerializer
+ lookup_field = 'uuid'
+
+
+class MockViewSet(viewsets.ModelViewSet):
+ queryset = None
+ serializer_class = None
+
+
+notes_router = SimpleRouter()
+notes_router.register(r'notes', NoteViewSet)
+
+namespaced_router = DefaultRouter()
+namespaced_router.register(r'example', MockViewSet, base_name='example')
+
+urlpatterns = [
+ url(r'^non-namespaced/', include(namespaced_router.urls)),
+ url(r'^namespaced/', include(namespaced_router.urls, namespace='example')),
+ url(r'^example/', include(notes_router.urls)),
+]
class BasicViewSet(viewsets.ViewSet):
@@ -63,9 +99,26 @@ class TestSimpleRouter(TestCase):
self.assertEqual(route.mapping[method], endpoint)
-class RouterTestModel(models.Model):
- uuid = models.CharField(max_length=20)
- text = models.CharField(max_length=200)
+class TestRootView(TestCase):
+ urls = 'tests.test_routers'
+
+ def test_retrieve_namespaced_root(self):
+ response = self.client.get('/namespaced/')
+ self.assertEqual(
+ response.data,
+ {
+ "example": "http://testserver/namespaced/example/",
+ }
+ )
+
+ def test_retrieve_non_namespaced_root(self):
+ response = self.client.get('/non-namespaced/')
+ self.assertEqual(
+ response.data,
+ {
+ "example": "http://testserver/non-namespaced/example/",
+ }
+ )
class TestCustomLookupFields(TestCase):
@@ -75,51 +128,29 @@ class TestCustomLookupFields(TestCase):
urls = 'tests.test_routers'
def setUp(self):
- class NoteSerializer(serializers.HyperlinkedModelSerializer):
- url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
-
- class Meta:
- model = RouterTestModel
- fields = ('url', 'uuid', 'text')
-
- class NoteViewSet(viewsets.ModelViewSet):
- queryset = RouterTestModel.objects.all()
- serializer_class = NoteSerializer
- lookup_field = 'uuid'
-
- self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
-
- from tests import test_routers
- urls = getattr(test_routers, 'urlpatterns')
- urls += patterns(
- '',
- url(r'^', include(self.router.urls)),
- )
-
RouterTestModel.objects.create(uuid='123', text='foo bar')
def test_custom_lookup_field_route(self):
- detail_route = self.router.urls[-1]
+ detail_route = notes_router.urls[-1]
detail_url_pattern = detail_route.regex.pattern
self.assertIn('<uuid>', detail_url_pattern)
def test_retrieve_lookup_field_list_view(self):
- response = self.client.get('/notes/')
+ response = self.client.get('/example/notes/')
self.assertEqual(
response.data,
[{
- "url": "http://testserver/notes/123/",
+ "url": "http://testserver/example/notes/123/",
"uuid": "123", "text": "foo bar"
}]
)
def test_retrieve_lookup_field_detail_view(self):
- response = self.client.get('/notes/123/')
+ response = self.client.get('/example/notes/123/')
self.assertEqual(
response.data,
{
- "url": "http://testserver/notes/123/",
+ "url": "http://testserver/example/notes/123/",
"uuid": "123", "text": "foo bar"
}
)
@@ -149,7 +180,7 @@ class TestLookupValueRegex(TestCase):
class TestTrailingSlashIncluded(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
+ queryset = RouterTestModel.objects.all()
self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet)
@@ -164,7 +195,7 @@ class TestTrailingSlashIncluded(TestCase):
class TestTrailingSlashRemoved(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
+ queryset = RouterTestModel.objects.all()
self.router = SimpleRouter(trailing_slash=False)
self.router.register(r'notes', NoteViewSet)
@@ -179,7 +210,8 @@ class TestTrailingSlashRemoved(TestCase):
class TestNameableRoot(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
+ queryset = RouterTestModel.objects.all()
+
self.router = DefaultRouter()
self.router.root_view_name = 'nameable-root'
self.router.register(r'notes', NoteViewSet)
@@ -261,6 +293,14 @@ class DynamicListAndDetailViewSet(viewsets.ViewSet):
def detail_route_get(self, request, *args, **kwargs):
return Response({'method': 'link2'})
+ @list_route(url_path="list_custom-route")
+ def list_custom_route_get(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @detail_route(url_path="detail_custom-route")
+ def detail_custom_route_get(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
class TestDynamicListAndDetailRouter(TestCase):
def setUp(self):
@@ -269,35 +309,30 @@ class TestDynamicListAndDetailRouter(TestCase):
def test_list_and_detail_route_decorators(self):
routes = self.router.get_routes(DynamicListAndDetailViewSet)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
+
+ MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
# Make sure all these endpoints exist and none have been clobbered
- for i, endpoint in enumerate(['list_route_get', 'list_route_post', 'detail_route_get', 'detail_route_post']):
+ for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'),
+ MethodNamesMap('list_route_get', 'list_route_get'),
+ MethodNamesMap('list_route_post', 'list_route_post'),
+ MethodNamesMap('detail_custom_route_get', 'detail_custom-route'),
+ MethodNamesMap('detail_route_get', 'detail_route_get'),
+ MethodNamesMap('detail_route_post', 'detail_route_post')
+ ]):
route = decorator_routes[i]
# check url listing
- if endpoint.startswith('list_'):
+ method_name = endpoint.method_name
+ url_path = endpoint.url_path
+
+ if method_name.startswith('list_'):
self.assertEqual(route.url,
- '^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint))
+ '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path))
else:
self.assertEqual(route.url,
- '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path))
# check method to function mapping
- if endpoint.endswith('_post'):
+ if method_name.endswith('_post'):
method_map = 'post'
else:
method_map = 'get'
- self.assertEqual(route.mapping[method_map], endpoint)
-
-
-class TestRootWithAListlessViewset(TestCase):
- def setUp(self):
- class NoteViewSet(mixins.RetrieveModelMixin,
- viewsets.GenericViewSet):
- model = RouterTestModel
-
- self.router = DefaultRouter()
- self.router.register(r'notes', NoteViewSet)
- self.view = self.router.urls[0].callback
-
- def test_api_root(self):
- request = factory.get('/')
- response = self.view(request)
- self.assertEqual(response.data, {})
+ self.assertEqual(route.mapping[method_map], method_name)
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
index c17b6d8c..b7a0484b 100644
--- a/tests/test_serializer.py
+++ b/tests/test_serializer.py
@@ -1,7 +1,9 @@
# coding: utf-8
from __future__ import unicode_literals
+from .utils import MockObject
from rest_framework import serializers
from rest_framework.compat import unicode_repr
+import pickle
import pytest
@@ -216,3 +218,80 @@ class TestUnicodeRepr:
instance = ExampleObject()
serializer = ExampleSerializer(instance)
repr(serializer) # Should not error.
+
+
+class TestNotRequiredOutput:
+ def test_not_required_output_for_dict(self):
+ """
+ 'required=False' should allow a dictionary key to be missing in output.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ omitted = serializers.CharField(required=False)
+ included = serializers.CharField()
+
+ serializer = ExampleSerializer(data={'included': 'abc'})
+ serializer.is_valid()
+ assert serializer.data == {'included': 'abc'}
+
+ def test_not_required_output_for_object(self):
+ """
+ 'required=False' should allow an object attribute to be missing in output.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ omitted = serializers.CharField(required=False)
+ included = serializers.CharField()
+
+ def create(self, validated_data):
+ return MockObject(**validated_data)
+
+ serializer = ExampleSerializer(data={'included': 'abc'})
+ serializer.is_valid()
+ serializer.save()
+ assert serializer.data == {'included': 'abc'}
+
+ def test_default_required_output_for_dict(self):
+ """
+ 'default="something"' should require dictionary key.
+
+ We need to handle this as the field will have an implicit
+ 'required=False', but it should still have a value.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ omitted = serializers.CharField(default='abc')
+ included = serializers.CharField()
+
+ serializer = ExampleSerializer({'included': 'abc'})
+ with pytest.raises(KeyError):
+ serializer.data
+
+ def test_default_required_output_for_object(self):
+ """
+ 'default="something"' should require object attribute.
+
+ We need to handle this as the field will have an implicit
+ 'required=False', but it should still have a value.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ omitted = serializers.CharField(default='abc')
+ included = serializers.CharField()
+
+ instance = MockObject(included='abc')
+ serializer = ExampleSerializer(instance)
+ with pytest.raises(AttributeError):
+ serializer.data
+
+
+class TestCacheSerializerData:
+ def test_cache_serializer_data(self):
+ """
+ Caching serializer data with pickle will drop the serializer info,
+ but does preserve the data itself.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ field1 = serializers.CharField()
+ field2 = serializers.CharField()
+
+ serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'})
+ pickled = pickle.dumps(serializer.data)
+ data = pickle.loads(pickled)
+ assert data == {'field1': 'a', 'field2': 'b'}
diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py
index fb881a75..bc955b2e 100644
--- a/tests/test_serializer_bulk_update.py
+++ b/tests/test_serializer_bulk_update.py
@@ -101,7 +101,7 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
- expected_errors = {'non_field_errors': ['Expected a list of items but got type `int`.']}
+ expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
self.assertEqual(serializer.errors, expected_errors)
@@ -118,6 +118,6 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
- expected_errors = {'non_field_errors': ['Expected a list of items but got type `dict`.']}
+ expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
self.assertEqual(serializer.errors, expected_errors)
diff --git a/tests/test_versioning.py b/tests/test_versioning.py
index c44f727d..553463d1 100644
--- a/tests/test_versioning.py
+++ b/tests/test_versioning.py
@@ -1,9 +1,13 @@
+from .utils import UsingURLPatterns
from django.conf.urls import include, url
+from rest_framework import serializers
from rest_framework import status, versioning
from rest_framework.decorators import APIView
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory, APITestCase
+from rest_framework.versioning import NamespaceVersioning
+import pytest
class RequestVersionView(APIView):
@@ -28,17 +32,8 @@ class RequestInvalidVersionView(APIView):
factory = APIRequestFactory()
-mock_view = lambda request: None
-
-included_patterns = [
- url(r'^namespaced/$', mock_view, name='another'),
-]
-
-urlpatterns = [
- url(r'^v1/', include(included_patterns, namespace='v1')),
- url(r'^another/$', mock_view, name='another'),
- url(r'^(?P<version>[^/]+)/another/$', mock_view, name='another')
-]
+dummy_view = lambda request: None
+dummy_pk_view = lambda request, pk: None
class TestRequestVersion:
@@ -114,8 +109,17 @@ class TestRequestVersion:
assert response.data == {'version': None}
-class TestURLReversing(APITestCase):
- urls = 'tests.test_versioning'
+class TestURLReversing(UsingURLPatterns, APITestCase):
+ included = [
+ url(r'^namespaced/$', dummy_view, name='another'),
+ url(r'^example/(?P<pk>\d+)/$', dummy_pk_view, name='example-detail')
+ ]
+
+ urlpatterns = [
+ url(r'^v1/', include(included, namespace='v1')),
+ url(r'^another/$', dummy_view, name='another'),
+ url(r'^(?P<version>[^/]+)/another/$', dummy_view, name='another'),
+ ]
def test_reverse_unversioned(self):
view = ReverseView.as_view()
@@ -221,3 +225,35 @@ class TestInvalidVersion:
request.resolver_match = FakeResolverMatch
response = view(request, version='v3')
assert response.status_code == status.HTTP_404_NOT_FOUND
+
+
+class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase):
+ included = [
+ url(r'^namespaced/(?P<pk>\d+)/$', dummy_view, name='namespaced'),
+ ]
+
+ urlpatterns = [
+ url(r'^v1/', include(included, namespace='v1')),
+ url(r'^v2/', include(included, namespace='v2'))
+ ]
+
+ def setUp(self):
+ super(TestHyperlinkedRelatedField, self).setUp()
+
+ class MockQueryset(object):
+ def get(self, pk):
+ return 'object %s' % pk
+
+ self.field = serializers.HyperlinkedRelatedField(
+ view_name='namespaced',
+ queryset=MockQueryset()
+ )
+ request = factory.get('/')
+ request.versioning_scheme = NamespaceVersioning()
+ request.version = 'v1'
+ self.field._context = {'request': request}
+
+ def test_bug_2489(self):
+ assert self.field.to_internal_value('/v1/namespaced/3/') == 'object 3'
+ with pytest.raises(serializers.ValidationError):
+ self.field.to_internal_value('/v2/namespaced/3/')
diff --git a/tests/utils.py b/tests/utils.py
index 5e902ba9..b9034996 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,30 +1,29 @@
-from contextlib import contextmanager
from django.core.exceptions import ObjectDoesNotExist
from django.core.urlresolvers import NoReverseMatch
-from django.utils import six
-from rest_framework.settings import api_settings
-@contextmanager
-def temporary_setting(setting, value, module=None):
+class UsingURLPatterns(object):
"""
- Temporarily change value of setting for test.
+ Isolates URL patterns used during testing on the test class itself.
+ For example:
- Optionally reload given module, useful when module uses value of setting on
- import.
- """
- original_value = getattr(api_settings, setting)
- setattr(api_settings, setting, value)
-
- if module is not None:
- six.moves.reload_module(module)
+ class MyTestCase(UsingURLPatterns, TestCase):
+ urlpatterns = [
+ ...
+ ]
- yield
+ def test_something(self):
+ ...
+ """
+ urls = __name__
- setattr(api_settings, setting, original_value)
+ def setUp(self):
+ global urlpatterns
+ urlpatterns = self.urlpatterns
- if module is not None:
- six.moves.reload_module(module)
+ def tearDown(self):
+ global urlpatterns
+ urlpatterns = []
class MockObject(object):