From ea8c40520165fc33343fceb15221b770701bdedf Mon Sep 17 00:00:00 2001 From: tanwanirahul Date: Mon, 3 Nov 2014 14:44:47 +0100 Subject: Tests for validating custom_method_name router attribute --- tests/test_routers.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/test_routers.py b/tests/test_routers.py index f6f5a977..d426f832 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -8,6 +8,7 @@ from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response from rest_framework.routers import SimpleRouter, DefaultRouter from rest_framework.test import APIRequestFactory +from collections import namedtuple factory = APIRequestFactory() @@ -260,6 +261,14 @@ class DynamicListAndDetailViewSet(viewsets.ViewSet): def detail_route_get(self, request, *args, **kwargs): return Response({'method': 'link2'}) + @list_route(custom_method_name="list_custom-route") + def list_custom_route_get(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @detail_route(custom_method_name="detail_custom-route") + def detail_custom_route_get(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + class TestDynamicListAndDetailRouter(TestCase): def setUp(self): @@ -268,22 +277,33 @@ class TestDynamicListAndDetailRouter(TestCase): def test_list_and_detail_route_decorators(self): routes = self.router.get_routes(DynamicListAndDetailViewSet) decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] + + MethodNamesMap = namedtuple('MethodNamesMap', 'method_name custom_method_name') # Make sure all these endpoints exist and none have been clobbered - for i, endpoint in enumerate(['list_route_get', 'list_route_post', 'detail_route_get', 'detail_route_post']): + for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), + MethodNamesMap('list_route_get', 'list_route_get'), + MethodNamesMap('list_route_post', 'list_route_post'), + MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), + MethodNamesMap('detail_route_get', 'detail_route_get'), + MethodNamesMap('detail_route_post', 'detail_route_post') + ]): route = decorator_routes[i] # check url listing - if endpoint.startswith('list_'): + method_name = endpoint.method_name + custom_method_name = endpoint.custom_method_name + + if method_name.startswith('list_'): self.assertEqual(route.url, - '^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint)) + '^{{prefix}}/{0}{{trailing_slash}}$'.format(custom_method_name)) else: self.assertEqual(route.url, - '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) + '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(custom_method_name)) # check method to function mapping - if endpoint.endswith('_post'): + if method_name.endswith('_post'): method_map = 'post' else: method_map = 'get' - self.assertEqual(route.mapping[method_map], endpoint) + self.assertEqual(route.mapping[method_map], method_name) class TestRootWithAListlessViewset(TestCase): -- cgit v1.2.3 From 2448cc8e856369ca6fb99b848e10f8ff0105e925 Mon Sep 17 00:00:00 2001 From: tanwanirahul Date: Fri, 19 Dec 2014 19:53:48 +0530 Subject: Updated tests to use url_path attribute in list and detail decorators --- tests/test_routers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'tests') diff --git a/tests/test_routers.py b/tests/test_routers.py index d426f832..73d10822 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -261,11 +261,11 @@ class DynamicListAndDetailViewSet(viewsets.ViewSet): def detail_route_get(self, request, *args, **kwargs): return Response({'method': 'link2'}) - @list_route(custom_method_name="list_custom-route") + @list_route(url_path="list_custom-route") def list_custom_route_get(self, request, *args, **kwargs): return Response({'method': 'link1'}) - @detail_route(custom_method_name="detail_custom-route") + @detail_route(url_path="detail_custom-route") def detail_custom_route_get(self, request, *args, **kwargs): return Response({'method': 'link2'}) @@ -278,7 +278,7 @@ class TestDynamicListAndDetailRouter(TestCase): routes = self.router.get_routes(DynamicListAndDetailViewSet) decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] - MethodNamesMap = namedtuple('MethodNamesMap', 'method_name custom_method_name') + MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path') # Make sure all these endpoints exist and none have been clobbered for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), MethodNamesMap('list_route_get', 'list_route_get'), @@ -290,14 +290,14 @@ class TestDynamicListAndDetailRouter(TestCase): route = decorator_routes[i] # check url listing method_name = endpoint.method_name - custom_method_name = endpoint.custom_method_name + url_path = endpoint.url_path if method_name.startswith('list_'): self.assertEqual(route.url, - '^{{prefix}}/{0}{{trailing_slash}}$'.format(custom_method_name)) + '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)) else: self.assertEqual(route.url, - '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(custom_method_name)) + '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)) # check method to function mapping if method_name.endswith('_post'): method_map = 'post' -- cgit v1.2.3 From 2a1485e00943b8280245d19e1e1f8514b1ef18ea Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 19 Dec 2014 21:32:43 +0000 Subject: Final bits of docs for ModelSerializer fields API --- tests/test_model_serializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 5c56c8db..603faf47 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -206,7 +206,7 @@ class TestRegularFieldMappings(TestCase): with self.assertRaises(ImproperlyConfigured) as excinfo: TestSerializer().fields - expected = 'Field name `invalid` is not valid for model `ModelBase`.' + expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.' assert str(excinfo.exception) == expected def test_missing_field(self): -- cgit v1.2.3 From 77e3021fea3e30382b9770eac25371495e0b156b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 20 Dec 2014 16:26:51 +0000 Subject: Better behaviour with null and '' for blank HTML fields. --- tests/test_fields.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) (limited to 'tests') diff --git a/tests/test_fields.py b/tests/test_fields.py index 04c721d3..775d4618 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -223,8 +223,8 @@ class MockHTMLDict(dict): getlist = None -class TestCharHTMLInput: - def test_empty_html_checkbox(self): +class TestHTMLInput: + def test_empty_html_charfield(self): class TestSerializer(serializers.Serializer): message = serializers.CharField(default='happy') @@ -232,23 +232,31 @@ class TestCharHTMLInput: assert serializer.is_valid() assert serializer.validated_data == {'message': 'happy'} - def test_empty_html_checkbox_allow_null(self): + def test_empty_html_charfield_allow_null(self): class TestSerializer(serializers.Serializer): message = serializers.CharField(allow_null=True) - serializer = TestSerializer(data=MockHTMLDict()) + serializer = TestSerializer(data=MockHTMLDict({'message': ''})) assert serializer.is_valid() assert serializer.validated_data == {'message': None} - def test_empty_html_checkbox_allow_null_allow_blank(self): + def test_empty_html_datefield_allow_null(self): + class TestSerializer(serializers.Serializer): + expiry = serializers.DateField(allow_null=True) + + serializer = TestSerializer(data=MockHTMLDict({'expiry': ''})) + assert serializer.is_valid() + assert serializer.validated_data == {'expiry': None} + + def test_empty_html_charfield_allow_null_allow_blank(self): class TestSerializer(serializers.Serializer): message = serializers.CharField(allow_null=True, allow_blank=True) - serializer = TestSerializer(data=MockHTMLDict({})) + serializer = TestSerializer(data=MockHTMLDict({'message': ''})) assert serializer.is_valid() assert serializer.validated_data == {'message': ''} - def test_empty_html_required_false(self): + def test_empty_html_charfield_required_false(self): class TestSerializer(serializers.Serializer): message = serializers.CharField(required=False) -- cgit v1.2.3 From b32ecdefbace063c5b9b465af608ac6404795dd4 Mon Sep 17 00:00:00 2001 From: Remi Paulmier Date: Wed, 24 Dec 2014 14:07:28 +0100 Subject: modified the tests accordingly --- tests/test_model_serializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index da79164a..ee556dbc 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -119,7 +119,7 @@ class TestRegularFieldMappings(TestCase): positive_small_integer_field = IntegerField() slug_field = SlugField(max_length=100) small_integer_field = IntegerField() - text_field = CharField(style={'type': 'textarea'}) + text_field = CharField(style={'base_template': 'textarea.html'}) time_field = TimeField() url_field = URLField(max_length=100) custom_field = ModelField(model_field=) -- cgit v1.2.3 From 7b42c5ed17a2430d66da88932ad4e81492d9b914 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 28 Dec 2014 11:14:32 +0000 Subject: Remove broken test. Closes #2359. --- tests/test_routers.py | 16 ---------------- 1 file changed, 16 deletions(-) (limited to 'tests') diff --git a/tests/test_routers.py b/tests/test_routers.py index 06ab8103..2b6cd7d2 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -305,19 +305,3 @@ class TestDynamicListAndDetailRouter(TestCase): else: method_map = 'get' self.assertEqual(route.mapping[method_map], method_name) - - -class TestRootWithAListlessViewset(TestCase): - def setUp(self): - class NoteViewSet(mixins.RetrieveModelMixin, - viewsets.GenericViewSet): - model = RouterTestModel - - self.router = DefaultRouter() - self.router.register(r'notes', NoteViewSet) - self.view = self.router.urls[0].callback - - def test_api_root(self): - request = factory.get('/') - response = self.view(request) - self.assertEqual(response.data, {}) -- cgit v1.2.3 From 67fc002f91e5dc617dab45895ded32d6be6c2a40 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 28 Dec 2014 11:26:38 +0000 Subject: Drop unused import --- tests/test_routers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_routers.py b/tests/test_routers.py index 2b6cd7d2..fc22a8d9 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -3,7 +3,7 @@ from django.conf.urls import patterns, url, include from django.db import models from django.test import TestCase from django.core.exceptions import ImproperlyConfigured -from rest_framework import serializers, viewsets, mixins, permissions +from rest_framework import serializers, viewsets, permissions from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response from rest_framework.routers import SimpleRouter, DefaultRouter -- cgit v1.2.3 From efa5942ce1c5d2286fd91994b52fb73a5690426c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 28 Dec 2014 12:02:52 +0000 Subject: Support namespaced router URLs with DefaultRouter. --- tests/test_routers.py | 94 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 62 insertions(+), 32 deletions(-) (limited to 'tests') diff --git a/tests/test_routers.py b/tests/test_routers.py index fc22a8d9..86113f5d 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -1,5 +1,5 @@ from __future__ import unicode_literals -from django.conf.urls import patterns, url, include +from django.conf.urls import url, include from django.db import models from django.test import TestCase from django.core.exceptions import ImproperlyConfigured @@ -12,7 +12,42 @@ from collections import namedtuple factory = APIRequestFactory() -urlpatterns = patterns('',) + +class RouterTestModel(models.Model): + uuid = models.CharField(max_length=20) + text = models.CharField(max_length=200) + + +class NoteSerializer(serializers.HyperlinkedModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') + + class Meta: + model = RouterTestModel + fields = ('url', 'uuid', 'text') + + +class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() + serializer_class = NoteSerializer + lookup_field = 'uuid' + + +class MockViewSet(viewsets.ModelViewSet): + queryset = None + serializer_class = None + + +notes_router = SimpleRouter() +notes_router.register(r'notes', NoteViewSet) + +namespaced_router = DefaultRouter() +namespaced_router.register(r'example', MockViewSet, base_name='example') + +urlpatterns = [ + url(r'^non-namespaced/', include(namespaced_router.urls)), + url(r'^namespaced/', include(namespaced_router.urls, namespace='example')), + url(r'^example/', include(notes_router.urls)), +] class BasicViewSet(viewsets.ViewSet): @@ -64,9 +99,26 @@ class TestSimpleRouter(TestCase): self.assertEqual(route.mapping[method], endpoint) -class RouterTestModel(models.Model): - uuid = models.CharField(max_length=20) - text = models.CharField(max_length=200) +class TestRootView(TestCase): + urls = 'tests.test_routers' + + def test_retrieve_namespaced_root(self): + response = self.client.get('/namespaced/') + self.assertEqual( + response.data, + { + "example": "http://testserver/namespaced/example/", + } + ) + + def test_retrieve_non_namespaced_root(self): + response = self.client.get('/non-namespaced/') + self.assertEqual( + response.data, + { + "example": "http://testserver/non-namespaced/example/", + } + ) class TestCustomLookupFields(TestCase): @@ -76,51 +128,29 @@ class TestCustomLookupFields(TestCase): urls = 'tests.test_routers' def setUp(self): - class NoteSerializer(serializers.HyperlinkedModelSerializer): - url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') - - class Meta: - model = RouterTestModel - fields = ('url', 'uuid', 'text') - - class NoteViewSet(viewsets.ModelViewSet): - queryset = RouterTestModel.objects.all() - serializer_class = NoteSerializer - lookup_field = 'uuid' - - self.router = SimpleRouter() - self.router.register(r'notes', NoteViewSet) - - from tests import test_routers - urls = getattr(test_routers, 'urlpatterns') - urls += patterns( - '', - url(r'^', include(self.router.urls)), - ) - RouterTestModel.objects.create(uuid='123', text='foo bar') def test_custom_lookup_field_route(self): - detail_route = self.router.urls[-1] + detail_route = notes_router.urls[-1] detail_url_pattern = detail_route.regex.pattern self.assertIn('', detail_url_pattern) def test_retrieve_lookup_field_list_view(self): - response = self.client.get('/notes/') + response = self.client.get('/example/notes/') self.assertEqual( response.data, [{ - "url": "http://testserver/notes/123/", + "url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar" }] ) def test_retrieve_lookup_field_detail_view(self): - response = self.client.get('/notes/123/') + response = self.client.get('/example/notes/123/') self.assertEqual( response.data, { - "url": "http://testserver/notes/123/", + "url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar" } ) -- cgit v1.2.3 From 32506e20756c84677abb5ae49706446a0d250371 Mon Sep 17 00:00:00 2001 From: Craig Blaszczyk Date: Wed, 31 Dec 2014 13:14:09 +0000 Subject: update expected error messages in tests --- tests/test_fields.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'tests') diff --git a/tests/test_fields.py b/tests/test_fields.py index 04c721d3..61d39aff 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -640,8 +640,8 @@ class TestDateField(FieldValues): datetime.date(2001, 1, 1): datetime.date(2001, 1, 1), } invalid_inputs = { - 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], - '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], + 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], + '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'], } outputs = { @@ -658,7 +658,7 @@ class TestCustomInputFormatDateField(FieldValues): '1 Jan 2001': datetime.date(2001, 1, 1), } invalid_inputs = { - '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY'] + '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.'] } outputs = {} field = serializers.DateField(input_formats=['%d %b %Y']) @@ -702,8 +702,8 @@ class TestDateTimeField(FieldValues): '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()) } invalid_inputs = { - 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], - '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], + 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], + '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'], } outputs = { @@ -721,7 +721,7 @@ class TestCustomInputFormatDateTimeField(FieldValues): '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()), } invalid_inputs = { - '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY'] + '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.'] } outputs = {} field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y']) @@ -773,8 +773,8 @@ class TestTimeField(FieldValues): datetime.time(13, 00): datetime.time(13, 00), } invalid_inputs = { - 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], - '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], + 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], + '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], } outputs = { datetime.time(13, 00): '13:00:00' @@ -790,7 +790,7 @@ class TestCustomInputFormatTimeField(FieldValues): '1:00pm': datetime.time(13, 00), } invalid_inputs = { - '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]'], + '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'], } outputs = {} field = serializers.TimeField(input_formats=['%I:%M%p']) @@ -1028,7 +1028,7 @@ class TestListField(FieldValues): (['1', '2', '3'], [1, 2, 3]) ] invalid_inputs = [ - ('not a list', ['Expected a list of items but got type `str`']), + ('not a list', ['Expected a list of items but got type `str`.']), ([1, 2, 'error'], ['A valid integer is required.']) ] outputs = [ -- cgit v1.2.3 From b6ca7248ebcf95a95e1911aa0b130f653b8bf690 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 5 Jan 2015 14:32:12 +0000 Subject: required=False allows omission of value for output. Closes #2342 --- tests/test_serializer.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) (limited to 'tests') diff --git a/tests/test_serializer.py b/tests/test_serializer.py index c17b6d8c..68bbbe98 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,5 +1,6 @@ # coding: utf-8 from __future__ import unicode_literals +from .utils import MockObject from rest_framework import serializers from rest_framework.compat import unicode_repr import pytest @@ -216,3 +217,64 @@ class TestUnicodeRepr: instance = ExampleObject() serializer = ExampleSerializer(instance) repr(serializer) # Should not error. + + +class TestNotRequiredOutput: + def test_not_required_output_for_dict(self): + """ + 'required=False' should allow a dictionary key to be missing in output. + """ + class ExampleSerializer(serializers.Serializer): + omitted = serializers.CharField(required=False) + included = serializers.CharField() + + serializer = ExampleSerializer(data={'included': 'abc'}) + serializer.is_valid() + assert serializer.data == {'included': 'abc'} + + def test_not_required_output_for_object(self): + """ + 'required=False' should allow an object attribute to be missing in output. + """ + class ExampleSerializer(serializers.Serializer): + omitted = serializers.CharField(required=False) + included = serializers.CharField() + + def create(self, validated_data): + return MockObject(**validated_data) + + serializer = ExampleSerializer(data={'included': 'abc'}) + serializer.is_valid() + serializer.save() + assert serializer.data == {'included': 'abc'} + + def test_default_required_output_for_dict(self): + """ + 'default="something"' should require dictionary key. + + We need to handle this as the field will have an implicit + 'required=False', but it should still have a value. + """ + class ExampleSerializer(serializers.Serializer): + omitted = serializers.CharField(default='abc') + included = serializers.CharField() + + serializer = ExampleSerializer({'included': 'abc'}) + with pytest.raises(KeyError): + serializer.data + + def test_default_required_output_for_object(self): + """ + 'default="something"' should require object attribute. + + We need to handle this as the field will have an implicit + 'required=False', but it should still have a value. + """ + class ExampleSerializer(serializers.Serializer): + omitted = serializers.CharField(default='abc') + included = serializers.CharField() + + instance = MockObject(included='abc') + serializer = ExampleSerializer(instance) + with pytest.raises(AttributeError): + serializer.data -- cgit v1.2.3 From 91e316f7810157474d6246cd0024bd7f7cc31ff7 Mon Sep 17 00:00:00 2001 From: Craig Blaszczyk Date: Wed, 7 Jan 2015 12:46:23 +0000 Subject: prefer single quotes in source and double quotes in user visible strings; add some missing full stops to user visible strings --- tests/test_fields.py | 2 +- tests/test_generics.py | 6 +++--- tests/test_relations.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) (limited to 'tests') diff --git a/tests/test_fields.py b/tests/test_fields.py index 61d39aff..5ecb9857 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -439,7 +439,7 @@ class TestSlugField(FieldValues): 'slug-99': 'slug-99', } invalid_inputs = { - 'slug 99': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."] + 'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.'] } outputs = {} field = serializers.SlugField() diff --git a/tests/test_generics.py b/tests/test_generics.py index 94023c30..fba8718f 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -117,7 +117,7 @@ class TestRootView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'}) def test_delete_root_view(self): """ @@ -127,7 +127,7 @@ class TestRootView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'}) def test_post_cannot_set_id(self): """ @@ -181,7 +181,7 @@ class TestInstanceView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'}) def test_put_instance_view(self): """ diff --git a/tests/test_relations.py b/tests/test_relations.py index 62353dc2..08c92242 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -33,7 +33,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): with pytest.raises(serializers.ValidationError) as excinfo: self.field.to_internal_value(4) msg = excinfo.value.detail[0] - assert msg == "Invalid pk '4' - object does not exist." + assert msg == 'Invalid pk "4" - object does not exist.' def test_pk_related_lookup_invalid_type(self): with pytest.raises(serializers.ValidationError) as excinfo: -- cgit v1.2.3 From 7f8d314101c4e6e059b00ac12658f0e1055da8f7 Mon Sep 17 00:00:00 2001 From: Craig Blaszczyk Date: Thu, 8 Jan 2015 17:16:47 +0000 Subject: update tests to expect new error messages --- tests/test_fields.py | 18 +++++++++--------- tests/test_serializer_bulk_update.py | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/tests/test_fields.py b/tests/test_fields.py index 5ecb9857..240827ee 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -338,7 +338,7 @@ class TestBooleanField(FieldValues): False: False, } invalid_inputs = { - 'foo': ['`foo` is not a valid boolean.'], + 'foo': ['"foo" is not a valid boolean.'], None: ['This field may not be null.'] } outputs = { @@ -368,7 +368,7 @@ class TestNullBooleanField(FieldValues): None: None } invalid_inputs = { - 'foo': ['`foo` is not a valid boolean.'], + 'foo': ['"foo" is not a valid boolean.'], } outputs = { 'true': True, @@ -832,7 +832,7 @@ class TestChoiceField(FieldValues): 'good': 'good', } invalid_inputs = { - 'amazing': ['`amazing` is not a valid choice.'] + 'amazing': ['"amazing" is not a valid choice.'] } outputs = { 'good': 'good', @@ -872,8 +872,8 @@ class TestChoiceFieldWithType(FieldValues): 3: 3, } invalid_inputs = { - 5: ['`5` is not a valid choice.'], - 'abc': ['`abc` is not a valid choice.'] + 5: ['"5" is not a valid choice.'], + 'abc': ['"abc" is not a valid choice.'] } outputs = { '1': 1, @@ -899,7 +899,7 @@ class TestChoiceFieldWithListChoices(FieldValues): 'good': 'good', } invalid_inputs = { - 'awful': ['`awful` is not a valid choice.'] + 'awful': ['"awful" is not a valid choice.'] } outputs = { 'good': 'good' @@ -917,8 +917,8 @@ class TestMultipleChoiceField(FieldValues): ('aircon', 'manual'): set(['aircon', 'manual']), } invalid_inputs = { - 'abc': ['Expected a list of items but got type `str`.'], - ('aircon', 'incorrect'): ['`incorrect` is not a valid choice.'] + 'abc': ['Expected a list of items but got type "str".'], + ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.'] } outputs = [ (['aircon', 'manual'], set(['aircon', 'manual'])) @@ -1028,7 +1028,7 @@ class TestListField(FieldValues): (['1', '2', '3'], [1, 2, 3]) ] invalid_inputs = [ - ('not a list', ['Expected a list of items but got type `str`.']), + ('not a list', ['Expected a list of items but got type "str".']), ([1, 2, 'error'], ['A valid integer is required.']) ] outputs = [ diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py index fb881a75..bc955b2e 100644 --- a/tests/test_serializer_bulk_update.py +++ b/tests/test_serializer_bulk_update.py @@ -101,7 +101,7 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items but got type `int`.']} + expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} self.assertEqual(serializer.errors, expected_errors) @@ -118,6 +118,6 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items but got type `dict`.']} + expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} self.assertEqual(serializer.errors, expected_errors) -- cgit v1.2.3 From 73feaf6299827607eab94ce96b77b73671880626 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 9 Jan 2015 15:30:36 +0000 Subject: First pass at 3.1 pagination API --- tests/test_pagination.py | 216 +---------------------------------------------- 1 file changed, 1 insertion(+), 215 deletions(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1fd9cf9c..d410cd5e 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,10 +1,9 @@ from __future__ import unicode_literals import datetime from decimal import Decimal -from django.core.paginator import Paginator from django.test import TestCase from django.utils import unittest -from rest_framework import generics, serializers, status, pagination, filters +from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from .models import BasicModel, FilterableItem @@ -238,45 +237,6 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.assertEqual(response.data['previous'], None) -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): - """ - Unit tests for pagination of primitive objects. - """ - - def setUp(self): - self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] - paginator = Paginator(self.objects, 10) - self.first_page = paginator.page(1) - self.last_page = paginator.page(3) - - def test_native_pagination(self): - serializer = pagination.PaginationSerializer(self.first_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], '?page=2') - self.assertEqual(serializer.data['previous'], None) - self.assertEqual(serializer.data['results'], self.objects[:10]) - - serializer = pagination.PaginationSerializer(self.last_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], None) - self.assertEqual(serializer.data['previous'], '?page=2') - self.assertEqual(serializer.data['results'], self.objects[20:]) - - def test_context_available_in_result(self): - """ - Ensure context gets passed through to the object serializer. - """ - serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) - serializer.data - results = serializer.fields[serializer.results_field] - self.assertEqual(serializer.context, results.context) - - class TestUnpaginated(TestCase): """ Tests for list views without pagination. @@ -377,177 +337,3 @@ class TestMaxPaginateByParam(TestCase): request = factory.get('/') response = self.view(request).render() self.assertEqual(response.data['results'], self.data[:3]) - - -# Tests for context in pagination serializers - -class CustomField(serializers.ReadOnlyField): - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into custom field") - return "value" - - -class BasicModelSerializer(serializers.Serializer): - text = CustomField() - - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into serializer") - return super(BasicSerializer, self).to_native(value) - - -class TestContextPassedToCustomField(TestCase): - def setUp(self): - BasicModel.objects.create(text='ala ma kota') - - def test_with_pagination(self): - class ListView(generics.ListCreateAPIView): - queryset = BasicModel.objects.all() - serializer_class = BasicModelSerializer - paginate_by = 1 - - self.view = ListView.as_view() - request = factory.get('/') - response = self.view(request).render() - - self.assertEqual(response.status_code, status.HTTP_200_OK) - - -# Tests for custom pagination serializers - -class LinksSerializer(serializers.Serializer): - next = pagination.NextPageField(source='*') - prev = pagination.PreviousPageField(source='*') - - -class CustomPaginationSerializer(pagination.BasePaginationSerializer): - links = LinksSerializer(source='*') # Takes the page object as the source - total_results = serializers.ReadOnlyField(source='paginator.count') - - results_field = 'objects' - - -class CustomFooSerializer(serializers.Serializer): - foo = serializers.CharField() - - -class CustomFooPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = CustomFooSerializer - - -class TestCustomPaginationSerializer(TestCase): - def setUp(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = Paginator(objects, 2) - self.page = paginator.page(1) - - def test_custom_pagination_serializer(self): - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=self.page, - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page=2', - 'prev': None - }, - 'total_results': 4, - 'objects': ['john', 'paul'] - } - self.assertEqual(serializer.data, expected) - - def test_custom_pagination_serializer_with_custom_object_serializer(self): - objects = [ - {'foo': 'bar'}, - {'foo': 'spam'} - ] - paginator = Paginator(objects, 1) - page = paginator.page(1) - serializer = CustomFooPaginationSerializer(page) - serializer.data - - -class NonIntegerPage(object): - - def __init__(self, paginator, object_list, prev_token, token, next_token): - self.paginator = paginator - self.object_list = object_list - self.prev_token = prev_token - self.token = token - self.next_token = next_token - - def has_next(self): - return not not self.next_token - - def next_page_number(self): - return self.next_token - - def has_previous(self): - return not not self.prev_token - - def previous_page_number(self): - return self.prev_token - - -class NonIntegerPaginator(object): - - def __init__(self, object_list, per_page): - self.object_list = object_list - self.per_page = per_page - - def count(self): - # pretend like we don't know how many pages we have - return None - - def page(self, token=None): - if token: - try: - first = self.object_list.index(token) - except ValueError: - first = 0 - else: - first = 0 - n = len(self.object_list) - last = min(first + self.per_page, n) - prev_token = self.object_list[last - (2 * self.per_page)] if first else None - next_token = self.object_list[last] if last < n else None - return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): - def test_custom_pagination_serializer(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = NonIntegerPaginator(objects, 2) - - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page(), - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page={0}'.format(objects[2]), - 'prev': None - }, - 'total_results': None, - 'objects': objects[:2] - } - self.assertEqual(serializer.data, expected) - - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page('george'), - context={'request': request} - ) - expected = { - 'links': { - 'next': None, - 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), - }, - 'total_results': None, - 'objects': objects[2:] - } - self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From 53edd37df5aa0ac29dbe7824db2e33da1d901f98 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 15 Jan 2015 21:07:05 +0000 Subject: Tests for LimitOffsetPagination --- tests/test_pagination.py | 117 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index d410cd5e..32fe7a66 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -3,8 +3,10 @@ import datetime from decimal import Decimal from django.test import TestCase from django.utils import unittest -from rest_framework import generics, serializers, status, filters +from rest_framework import generics, pagination, serializers, status, filters from rest_framework.compat import django_filters +from rest_framework.request import Request +from rest_framework.pagination import PageLink, PAGE_BREAK from rest_framework.test import APIRequestFactory from .models import BasicModel, FilterableItem @@ -337,3 +339,116 @@ class TestMaxPaginateByParam(TestCase): request = factory.get('/') response = self.view(request).render() self.assertEqual(response.data['results'], self.data[:3]) + + +class TestLimitOffset: + def setup(self): + self.pagination = pagination.LimitOffsetPagination() + self.queryset = range(1, 101) + + def paginate_queryset(self, request): + return self.pagination.paginate_queryset(self.queryset, request) + + def get_paginated_content(self, queryset): + response = self.pagination.get_paginated_response(queryset) + return response.data + + def get_html_context(self): + return self.pagination.get_html_context() + + def test_no_offset(self): + request = Request(factory.get('/', {'limit': 5})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [1, 2, 3, 4, 5] + assert content == { + 'results': [1, 2, 3, 4, 5], + 'previous': None, + 'next': 'http://testserver/?limit=5&offset=5', + 'count': 100 + } + assert context == { + 'previous_url': None, + 'next_url': 'http://testserver/?limit=5&offset=5', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, True, False), + PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + + def test_first_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 5})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [6, 7, 8, 9, 10] + assert content == { + 'results': [6, 7, 8, 9, 10], + 'previous': 'http://testserver/?limit=5', + 'next': 'http://testserver/?limit=5&offset=10', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5', + 'next_url': 'http://testserver/?limit=5&offset=10', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=5', 2, True, False), + PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + + def test_middle_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 10})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [11, 12, 13, 14, 15] + assert content == { + 'results': [11, 12, 13, 14, 15], + 'previous': 'http://testserver/?limit=5&offset=5', + 'next': 'http://testserver/?limit=5&offset=15', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5&offset=5', + 'next_url': 'http://testserver/?limit=5&offset=15', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + PageLink('http://testserver/?limit=5&offset=10', 3, True, False), + PageLink('http://testserver/?limit=5&offset=15', 4, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + + def test_ending_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 95})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [96, 97, 98, 99, 100] + assert content == { + 'results': [96, 97, 98, 99, 100], + 'previous': 'http://testserver/?limit=5&offset=90', + 'next': None, + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5&offset=90', + 'next_url': None, + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=85', 18, False, False), + PageLink('http://testserver/?limit=5&offset=90', 19, False, False), + PageLink('http://testserver/?limit=5&offset=95', 20, True, False), + ] + } -- cgit v1.2.3 From 50db8c092ab51a5eb94e2bb495c317097fceeb59 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 16 Jan 2015 16:55:28 +0000 Subject: Minor test cleanup --- tests/test_metadata.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 5ff59c72..972a896a 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,9 +1,7 @@ from __future__ import unicode_literals - -from rest_framework import exceptions, serializers, views +from rest_framework import exceptions, serializers, status, views from rest_framework.request import Request from rest_framework.test import APIRequestFactory -import pytest request = Request(APIRequestFactory().options('/')) @@ -17,7 +15,8 @@ class TestMetadata: """Example view.""" pass - response = ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) expected = { 'name': 'Example', 'description': 'Example view.', @@ -31,7 +30,7 @@ class TestMetadata: 'multipart/form-data' ] } - assert response.status_code == 200 + assert response.status_code == status.HTTP_200_OK assert response.data == expected def test_none_metadata(self): @@ -42,8 +41,10 @@ class TestMetadata: class ExampleView(views.APIView): metadata_class = None - with pytest.raises(exceptions.MethodNotAllowed): - ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + assert response.data == {'detail': 'Method "OPTIONS" not allowed.'} def test_actions(self): """ @@ -63,7 +64,8 @@ class TestMetadata: def get_serializer(self): return ExampleSerializer() - response = ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) expected = { 'name': 'Example', 'description': 'Example view.', @@ -104,7 +106,7 @@ class TestMetadata: } } } - assert response.status_code == 200 + assert response.status_code == status.HTTP_200_OK assert response.data == expected def test_global_permissions(self): @@ -132,8 +134,9 @@ class TestMetadata: if request.method == 'POST': raise exceptions.PermissionDenied() - response = ExampleView().options(request=request) - assert response.status_code == 200 + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_200_OK assert list(response.data['actions'].keys()) == ['PUT'] def test_object_permissions(self): @@ -161,6 +164,7 @@ class TestMetadata: if self.request.method == 'PUT': raise exceptions.PermissionDenied() - response = ExampleView().options(request=request) - assert response.status_code == 200 + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_200_OK assert list(response.data['actions'].keys()) == ['POST'] -- cgit v1.2.3 From 8b0f25aa0a91cb7b56f9ce4dde4330fe5daaad9b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 16 Jan 2015 16:55:46 +0000 Subject: More pagination tests & cleanup --- tests/test_pagination.py | 629 ++++++++++++++++++++++++----------------------- 1 file changed, 325 insertions(+), 304 deletions(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 32fe7a66..b3436b35 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,349 +1,270 @@ from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, pagination, serializers, status, filters -from rest_framework.compat import django_filters +from rest_framework import exceptions, generics, pagination, serializers, status, filters from rest_framework.request import Request from rest_framework.pagination import PageLink, PAGE_BREAK from rest_framework.test import APIRequestFactory -from .models import BasicModel, FilterableItem +import pytest factory = APIRequestFactory() -# Helper function to split arguments out of an url -def split_arguments_from_url(url): - if '?' not in url: - return url - - path, args = url.split('?') - args = dict(r.split('=') for r in args.split('&')) - return path, args +class TestPaginationIntegration: + """ + Integration tests. + """ + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item -class BasicSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel + class EvenItemsOnly(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return [item for item in queryset if item % 2 == 0] + + class BasicPagination(pagination.PageNumberPagination): + paginate_by = 5 + paginate_by_param = 'page_size' + max_paginate_by = 20 + + self.view = generics.ListAPIView.as_view( + serializer_class=PassThroughSerializer, + queryset=range(1, 101), + filter_backends=[EvenItemsOnly], + pagination_class=BasicPagination + ) + + def test_filtered_items_are_paginated(self): + request = factory.get('/', {'page': 2}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [12, 14, 16, 18, 20], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page=3', + 'count': 50 + } + def test_setting_page_size(self): + """ + When 'paginate_by_param' is set, the client may choose a page size. + """ + request = factory.get('/', {'page_size': 10}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=10', + 'count': 50 + } -class FilterableItemSerializer(serializers.ModelSerializer): - class Meta: - model = FilterableItem + def test_setting_page_size_over_maximum(self): + """ + When page_size parameter exceeds maxiumum allowable, + then it should be capped to the maxiumum. + """ + request = factory.get('/', {'page_size': 1000}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [ + 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, + 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 + ], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=1000', + 'count': 50 + } + def test_additional_query_params_are_preserved(self): + request = factory.get('/', {'page': 2, 'filter': 'even'}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [12, 14, 16, 18, 20], + 'previous': 'http://testserver/?filter=even', + 'next': 'http://testserver/?filter=even&page=3', + 'count': 50 + } -class RootView(generics.ListCreateAPIView): - """ - Example description for OPTIONS. - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by = 10 + def test_404_not_found_for_invalid_page(self): + request = factory.get('/', {'page': 'invalid'}) + response = self.view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == { + 'detail': 'Invalid page "invalid": That page number is not an integer.' + } -class DefaultPageSizeKwargView(generics.ListAPIView): +class TestPaginationDisabledIntegration: """ - View for testing default paginate_by_param usage + Integration tests for disabled pagination. """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - -class PaginateByParamView(generics.ListAPIView): - """ - View for testing custom paginate_by_param usage - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by_param = 'page_size' + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item + + self.view = generics.ListAPIView.as_view( + serializer_class=PassThroughSerializer, + queryset=range(1, 101), + pagination_class=None + ) + + def test_unpaginated_list(self): + request = factory.get('/', {'page': 2}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == range(1, 101) -class MaxPaginateByView(generics.ListAPIView): +class TestDeprecatedStylePagination: """ - View for testing custom max_paginate_by usage + Integration tests for deprecated style of setting pagination + attributes on the view. """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by = 3 - max_paginate_by = 5 - paginate_by_param = 'page_size' - -class IntegrationTestPagination(TestCase): - """ - Integration tests for paginated list views. - """ + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item - def setUp(self): - """ - Create 26 BasicModel instances. - """ - for char in 'abcdefghijklmnopqrstuvwxyz': - BasicModel(text=char * 3).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = RootView.as_view() - - def test_get_paginated_root_view(self): - """ - GET requests to paginated ListCreateAPIView should return paginated results. - """ - request = factory.get('/') - # Note: Database queries are a `SELECT COUNT`, and `SELECT ` - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[10:20]) - self.assertNotEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[20:]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - -class IntegrationTestPaginationAndFiltering(TestCase): - - def setUp(self): - """ - Create 50 FilterableItem instances. - """ - base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) - for i in range(26): - text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. - decimal = base_data[1] + i - date = base_data[2] - datetime.timedelta(days=i * 2) - FilterableItem(text=text, decimal=decimal, date=date).save() - - self.objects = FilterableItem.objects - self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} - for obj in self.objects.all() - ] - - @unittest.skipUnless(django_filters, 'django-filter not installed') - def test_get_django_filter_paginated_filtered_root_view(self): - """ - GET requests to paginated filtered ListCreateAPIView should return - paginated results. The next and previous links should preserve the - filtered parameters. - """ - class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterFieldsRootView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - paginate_by = 10 - filter_class = DecimalFilter - filter_backends = (filters.DjangoFilterBackend,) - - view = FilterFieldsRootView.as_view() - - EXPECTED_NUM_QUERIES = 2 - - request = factory.get('/', {'decimal': '15.20'}) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[10:15]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['previous'])) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - def test_get_basic_paginated_filtered_root_view(self): - """ - Same as `test_get_django_filter_paginated_filtered_root_view`, - except using a custom filter backend instead of the django-filter - backend, - """ + class ExampleView(generics.ListAPIView): + serializer_class = PassThroughSerializer + queryset = range(1, 101) + pagination_class = pagination.PageNumberPagination + paginate_by = 20 + page_query_param = 'page_number' - class DecimalFilterBackend(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) - - class BasicFilterFieldsRootView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - paginate_by = 10 - filter_backends = (DecimalFilterBackend,) - - view = BasicFilterFieldsRootView.as_view() - - request = factory.get('/', {'decimal': '15.20'}) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[10:15]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['previous'])) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - -class TestUnpaginated(TestCase): - """ - Tests for list views without pagination. - """ + self.view = ExampleView.as_view() - def setUp(self): - """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = DefaultPageSizeKwargView.as_view() - - def test_unpaginated(self): - """ - Tests the default page size for this view. - no page size --> no limit --> no meta data - """ - request = factory.get('/') + def test_paginate_by_attribute_on_view(self): + request = factory.get('/?page_number=2') response = self.view(request) - self.assertEqual(response.data, self.data) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [ + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 + ], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page_number=3', + 'count': 100 + } -class TestCustomPaginateByParam(TestCase): +class TestPageNumberPagination: """ - Tests for list views with default page size kwarg + Unit tests for `pagination.PageNumberPagination`. """ - def setUp(self): - """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = PaginateByParamView.as_view() - - def test_default_page_size(self): - """ - Tests the default page size for this view. - no page size --> no limit --> no meta data - """ - request = factory.get('/') - response = self.view(request).render() - self.assertEqual(response.data, self.data) + def setup(self): + class ExamplePagination(pagination.PageNumberPagination): + paginate_by = 5 + self.pagination = ExamplePagination() + self.queryset = range(1, 101) - def test_paginate_by_param(self): - """ - If paginate_by_param is set, the new kwarg should limit per view requests. - """ - request = factory.get('/', {'page_size': 5}) - response = self.view(request).render() - self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:5]) + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + def get_paginated_content(self, queryset): + response = self.pagination.get_paginated_response(queryset) + return response.data -class TestMaxPaginateByParam(TestCase): - """ - Tests for list views with max_paginate_by kwarg - """ + def get_html_context(self): + return self.pagination.get_html_context() - def setUp(self): - """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = MaxPaginateByView.as_view() - - def test_max_paginate_by(self): - """ - If max_paginate_by is set, it should limit page size for the view. - """ - request = factory.get('/', data={'page_size': 10}) - response = self.view(request).render() - self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:5]) + def test_no_page_number(self): + request = Request(factory.get('/')) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [1, 2, 3, 4, 5] + assert content == { + 'results': [1, 2, 3, 4, 5], + 'previous': None, + 'next': 'http://testserver/?page=2', + 'count': 100 + } + assert context == { + 'previous_url': None, + 'next_url': 'http://testserver/?page=2', + 'page_links': [ + PageLink('http://testserver/', 1, True, False), + PageLink('http://testserver/?page=2', 2, False, False), + PageLink('http://testserver/?page=3', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=20', 20, False, False), + ] + } + assert self.pagination.display_page_controls + assert isinstance(self.pagination.to_html(), type('')) - def test_max_paginate_by_without_page_size_param(self): - """ - If max_paginate_by is set, but client does not specifiy page_size, - standard `paginate_by` behavior should be used. - """ - request = factory.get('/') - response = self.view(request).render() - self.assertEqual(response.data['results'], self.data[:3]) + def test_second_page(self): + request = Request(factory.get('/', {'page': 2})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [6, 7, 8, 9, 10] + assert content == { + 'results': [6, 7, 8, 9, 10], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page=3', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/', + 'next_url': 'http://testserver/?page=3', + 'page_links': [ + PageLink('http://testserver/', 1, False, False), + PageLink('http://testserver/?page=2', 2, True, False), + PageLink('http://testserver/?page=3', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=20', 20, False, False), + ] + } + + def test_last_page(self): + request = Request(factory.get('/', {'page': 'last'})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [96, 97, 98, 99, 100] + assert content == { + 'results': [96, 97, 98, 99, 100], + 'previous': 'http://testserver/?page=19', + 'next': None, + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?page=19', + 'next_url': None, + 'page_links': [ + PageLink('http://testserver/', 1, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=18', 18, False, False), + PageLink('http://testserver/?page=19', 19, False, False), + PageLink('http://testserver/?page=20', 20, True, False), + ] + } + + def test_invalid_page(self): + request = Request(factory.get('/', {'page': 'invalid'})) + with pytest.raises(exceptions.NotFound): + self.paginate_queryset(request) class TestLimitOffset: + """ + Unit tests for `pagination.LimitOffsetPagination`. + """ + def setup(self): - self.pagination = pagination.LimitOffsetPagination() + class ExamplePagination(pagination.LimitOffsetPagination): + default_limit = 10 + self.pagination = ExamplePagination() self.queryset = range(1, 101) def paginate_queryset(self, request): @@ -379,6 +300,37 @@ class TestLimitOffset: PageLink('http://testserver/?limit=5&offset=95', 20, False, False), ] } + assert self.pagination.display_page_controls + assert isinstance(self.pagination.to_html(), type('')) + + def test_single_offset(self): + """ + When the offset is not a multiple of the limit we get some edge cases: + * The first page should still be offset zero. + * We may end up displaying an extra page in the pagination control. + """ + request = Request(factory.get('/', {'limit': 5, 'offset': 1})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [2, 3, 4, 5, 6] + assert content == { + 'results': [2, 3, 4, 5, 6], + 'previous': 'http://testserver/?limit=5', + 'next': 'http://testserver/?limit=5&offset=6', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5', + 'next_url': 'http://testserver/?limit=5&offset=6', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=1', 2, True, False), + PageLink('http://testserver/?limit=5&offset=6', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=96', 21, False, False), + ] + } def test_first_offset(self): request = Request(factory.get('/', {'limit': 5, 'offset': 5})) @@ -452,3 +404,72 @@ class TestLimitOffset: PageLink('http://testserver/?limit=5&offset=95', 20, True, False), ] } + + def test_invalid_offset(self): + """ + An invalid offset query param should be treated as 0. + """ + request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) + queryset = self.paginate_queryset(request) + assert queryset == [1, 2, 3, 4, 5] + + def test_invalid_limit(self): + """ + An invalid limit query param should be ignored in favor of the default. + """ + request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) + queryset = self.paginate_queryset(request) + assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + +def test_get_displayed_page_numbers(): + """ + Test our contextual page display function. + + This determines which pages to display in a pagination control, + given the current page and the last page. + """ + displayed_page_numbers = pagination._get_displayed_page_numbers + + # At five pages or less, all pages are displayed, always. + assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5] + + # Between six and either pages we may have a single page break. + assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6] + assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6] + assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6] + assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6] + assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6] + assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6] + + assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7] + assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7] + assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7] + assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7] + assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7] + assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7] + assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7] + + assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8] + assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8] + assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8] + assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8] + assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8] + assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8] + assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8] + assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8] + + # At nine or more pages we may have two page breaks, one on each side. + assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9] + assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9] + assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9] + assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9] + assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9] + assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9] + assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] + assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] + assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] -- cgit v1.2.3 From 86d2774cf30351fd4174e97501532056ed0d8f95 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 16 Jan 2015 20:30:46 +0000 Subject: Fix compat issues --- tests/test_pagination.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index b3436b35..7cc92347 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -117,7 +117,7 @@ class TestPaginationDisabledIntegration: request = factory.get('/', {'page': 2}) response = self.view(request) assert response.status_code == status.HTTP_200_OK - assert response.data == range(1, 101) + assert response.data == list(range(1, 101)) class TestDeprecatedStylePagination: @@ -268,7 +268,7 @@ class TestLimitOffset: self.queryset = range(1, 101) def paginate_queryset(self, request): - return self.pagination.paginate_queryset(self.queryset, request) + return list(self.pagination.paginate_queryset(self.queryset, request)) def get_paginated_content(self, queryset): response = self.pagination.get_paginated_response(queryset) -- cgit v1.2.3 From 4919492582547d227a22852ad2339fa73739cc94 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 17 Jan 2015 00:10:43 +0000 Subject: First pass at cursor pagination --- tests/test_pagination.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 7cc92347..7f18b446 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -422,6 +422,94 @@ class TestLimitOffset: assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +class TestCursorPagination: + """ + Unit tests for `pagination.CursorPagination`. + """ + + def setup(self): + class MockObject(object): + def __init__(self, idx): + self.created = idx + + class MockQuerySet(object): + def __init__(self, items): + self.items = items + + def filter(self, created__gt): + return [ + item for item in self.items + if item.created > int(created__gt) + ] + + def __getitem__(self, sliced): + return self.items[sliced] + + self.pagination = pagination.CursorPagination() + self.queryset = MockQuerySet( + [MockObject(idx) for idx in range(1, 21)] + ) + + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + + # def get_paginated_content(self, queryset): + # response = self.pagination.get_paginated_response(queryset) + # return response.data + + # def get_html_context(self): + # return self.pagination.get_html_context() + + def test_following_cursor(self): + request = Request(factory.get('/')) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [1, 2, 3, 4, 5] + + next_url = self.pagination.get_next_link() + assert next_url + + request = Request(factory.get(next_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [6, 7, 8, 9, 10] + + next_url = self.pagination.get_next_link() + assert next_url + + request = Request(factory.get(next_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [11, 12, 13, 14, 15] + + next_url = self.pagination.get_next_link() + assert next_url + + request = Request(factory.get(next_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [16, 17, 18, 19, 20] + + next_url = self.pagination.get_next_link() + assert next_url is None + + # assert content == { + # 'results': [1, 2, 3, 4, 5], + # 'previous': None, + # 'next': 'http://testserver/?limit=5&offset=5', + # 'count': 100 + # } + # assert context == { + # 'previous_url': None, + # 'next_url': 'http://testserver/?limit=5&offset=5', + # 'page_links': [ + # PageLink('http://testserver/?limit=5', 1, True, False), + # PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + # PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + # PAGE_BREAK, + # PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + # ] + # } + # assert self.pagination.display_page_controls + # assert isinstance(self.pagination.to_html(), type('')) + + def test_get_displayed_page_numbers(): """ Test our contextual page display function. -- cgit v1.2.3 From b5128ca574d03ea590a198b04043142a3fc7163e Mon Sep 17 00:00:00 2001 From: David Muller Date: Sun, 18 Jan 2015 15:19:11 -0800 Subject: Save objects before assigning them in InheritedModelSerializationTests; Django 1.8 now throws an error when assigning unsaved objects to Foreign Key, GenericForeignKey, and OneToOneFields --- tests/test_multitable_inheritance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py index e1b40cc7..15627e1d 100644 --- a/tests/test_multitable_inheritance.py +++ b/tests/test_multitable_inheritance.py @@ -48,8 +48,8 @@ class InheritedModelSerializationTests(TestCase): Assert that a model with a onetoone field that is the primary key is not treated like a derived model """ - parent = ParentModel(name1='parent name') - associate = AssociatedModel(name='hello', ref=parent) + parent = ParentModel.objects.create(name1='parent name') + associate = AssociatedModel.objects.create(name='hello', ref=parent) serializer = AssociatedModelSerializer(associate) self.assertEqual(set(serializer.data.keys()), set(['name', 'ref'])) -- cgit v1.2.3 From dbb684117f6fe0f9c34f98d5e914fc106090cdbc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 19 Jan 2015 09:24:42 +0000 Subject: Add offset support for cursor pagination --- tests/test_pagination.py | 64 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 7f18b446..f04079a7 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -447,7 +447,7 @@ class TestCursorPagination: self.pagination = pagination.CursorPagination() self.queryset = MockQuerySet( - [MockObject(idx) for idx in range(1, 21)] + [MockObject(idx) for idx in range(1, 16)] ) def paginate_queryset(self, request): @@ -479,16 +479,74 @@ class TestCursorPagination: queryset = self.paginate_queryset(request) assert [item.created for item in queryset] == [11, 12, 13, 14, 15] + next_url = self.pagination.get_next_link() + assert next_url is None + + +class TestCrazyCursorPagination: + """ + Unit tests for `pagination.CursorPagination`. + """ + + def setup(self): + class MockObject(object): + def __init__(self, idx): + self.created = idx + + class MockQuerySet(object): + def __init__(self, items): + self.items = items + + def filter(self, created__gt): + return [ + item for item in self.items + if item.created > int(created__gt) + ] + + def __getitem__(self, sliced): + return self.items[sliced] + + self.pagination = pagination.CursorPagination() + self.queryset = MockQuerySet([ + MockObject(idx) for idx in [ + 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, + 1, 1, 2, 3, 4, + 5, 6, 7, 8, 9 + ] + ]) + + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + + def test_following_cursor_identical_items(self): + request = Request(factory.get('/')) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [1, 1, 1, 1, 1] + next_url = self.pagination.get_next_link() assert next_url request = Request(factory.get(next_url)) queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [16, 17, 18, 19, 20] + assert [item.created for item in queryset] == [1, 1, 1, 1, 1] next_url = self.pagination.get_next_link() - assert next_url is None + assert next_url + + request = Request(factory.get(next_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [1, 1, 2, 3, 4] + + next_url = self.pagination.get_next_link() + assert next_url + + request = Request(factory.get(next_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [5, 6, 7, 8, 9] + next_url = self.pagination.get_next_link() + assert next_url is None # assert content == { # 'results': [1, 2, 3, 4, 5], # 'previous': None, -- cgit v1.2.3 From 4f3c3a06cfc0ea2dfbf46da2d98546664343ce93 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 19 Jan 2015 14:41:10 +0000 Subject: Drop trailing whitespace on indented JSON output. Closes #2429. --- tests/test_renderers.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 7b78f7ba..3e64d8fe 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals - from django.conf.urls import patterns, url, include from django.core.cache import cache from django.db import models @@ -8,6 +7,7 @@ from django.test import TestCase from django.utils import six from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions +from rest_framework.compat import OrderedDict from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, BrowsableAPIRenderer @@ -489,3 +489,25 @@ class CacheRenderTest(TestCase): cached_resp = cache.get(self.cache_key) self.assertIsInstance(cached_resp, Response) self.assertEqual(cached_resp.content, resp.content) + + +class TestJSONIndentationStyles: + def test_indented(self): + renderer = JSONRenderer() + data = OrderedDict([('a', 1), ('b', 2)]) + assert renderer.render(data) == b'{"a":1,"b":2}' + + def test_compact(self): + renderer = JSONRenderer() + data = OrderedDict([('a', 1), ('b', 2)]) + context = {'indent': 4} + assert ( + renderer.render(data, renderer_context=context) == + b'{\n "a": 1,\n "b": 2\n}' + ) + + def test_long_form(self): + renderer = JSONRenderer() + renderer.compact = False + data = OrderedDict([('a', 1), ('b', 2)]) + assert renderer.render(data) == b'{"a": 1, "b": 2}' -- cgit v1.2.3 From da6ef3d0b0f3a8e688524bbd446d4350a74fd05a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 21 Jan 2015 13:03:37 +0000 Subject: Allow missing fields option for inherited serializers. Closes #2388. --- tests/test_model_serializer.py | 52 +++++++++++++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 11 deletions(-) (limited to 'tests') diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index ee556dbc..247b309a 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -5,11 +5,14 @@ shortcuts for automatically creating serializers based on a given model class. These tests deal with ensuring that we correctly map the model fields onto an appropriate set of serializer fields for each case. """ +from __future__ import unicode_literals from django.core.exceptions import ImproperlyConfigured from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator from django.db import models from django.test import TestCase +from django.utils import six from rest_framework import serializers +from rest_framework.compat import unicode_repr def dedent(blocktext): @@ -124,7 +127,7 @@ class TestRegularFieldMappings(TestCase): url_field = URLField(max_length=100) custom_field = ModelField(model_field=) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_field_options(self): class TestSerializer(serializers.ModelSerializer): @@ -142,7 +145,14 @@ class TestRegularFieldMappings(TestCase): descriptive_field = IntegerField(help_text='Some help text', label='A label') choices_field = ChoiceField(choices=[('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')]) """) - self.assertEqual(repr(TestSerializer()), expected) + if six.PY2: + # This particular case is too awkward to resolve fully across + # both py2 and py3. + expected = expected.replace( + "('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')", + "(u'red', u'Red'), (u'blue', u'Blue'), (u'green', u'Green')" + ) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_method_field(self): """ @@ -221,7 +231,7 @@ class TestRegularFieldMappings(TestCase): model = RegularFieldsModel fields = ('auto_field',) - with self.assertRaises(ImproperlyConfigured) as excinfo: + with self.assertRaises(AssertionError) as excinfo: TestSerializer().fields expected = ( 'Field `missing` has been declared on serializer ' @@ -229,6 +239,26 @@ class TestRegularFieldMappings(TestCase): ) assert str(excinfo.exception) == expected + def test_missing_superclass_field(self): + """ + Fields that have been declared on a parent of the serializer class may + be excluded from the `Meta.fields` option. + """ + class TestSerializer(serializers.ModelSerializer): + missing = serializers.ReadOnlyField() + + class Meta: + model = RegularFieldsModel + + class ChildSerializer(TestSerializer): + missing = serializers.ReadOnlyField() + + class Meta: + model = RegularFieldsModel + fields = ('auto_field',) + + ChildSerializer().fields + # Tests for relational field mappings. # ------------------------------------ @@ -276,7 +306,7 @@ class TestRelationalFieldMappings(TestCase): many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all()) through = PrimaryKeyRelatedField(many=True, read_only=True) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_nested_relations(self): class TestSerializer(serializers.ModelSerializer): @@ -300,7 +330,7 @@ class TestRelationalFieldMappings(TestCase): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_hyperlinked_relations(self): class TestSerializer(serializers.HyperlinkedModelSerializer): @@ -315,7 +345,7 @@ class TestRelationalFieldMappings(TestCase): many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_nested_hyperlinked_relations(self): class TestSerializer(serializers.HyperlinkedModelSerializer): @@ -339,7 +369,7 @@ class TestRelationalFieldMappings(TestCase): url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') name = CharField(max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_foreign_key(self): class TestSerializer(serializers.ModelSerializer): @@ -353,7 +383,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_one_to_one(self): class TestSerializer(serializers.ModelSerializer): @@ -367,7 +397,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_many_to_many(self): class TestSerializer(serializers.ModelSerializer): @@ -381,7 +411,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_through(self): class TestSerializer(serializers.ModelSerializer): @@ -395,7 +425,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_through = PrimaryKeyRelatedField(many=True, read_only=True) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) class TestIntegration(TestCase): -- cgit v1.2.3 From e59b3d1718de549d0e165d03aeea1488ddfe20ee Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 21 Jan 2015 14:18:13 +0000 Subject: Make ReturnDict cachable. Closes #2360. --- tests/test_serializer.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'tests') diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 68bbbe98..b7a0484b 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from .utils import MockObject from rest_framework import serializers from rest_framework.compat import unicode_repr +import pickle import pytest @@ -278,3 +279,19 @@ class TestNotRequiredOutput: serializer = ExampleSerializer(instance) with pytest.raises(AttributeError): serializer.data + + +class TestCacheSerializerData: + def test_cache_serializer_data(self): + """ + Caching serializer data with pickle will drop the serializer info, + but does preserve the data itself. + """ + class ExampleSerializer(serializers.Serializer): + field1 = serializers.CharField() + field2 = serializers.CharField() + + serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'}) + pickled = pickle.dumps(serializer.data) + data = pickle.loads(pickled) + assert data == {'field1': 'a', 'field2': 'b'} -- cgit v1.2.3 From cae9528c54ea13863ea056d40168e8d8df68b276 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 10:28:19 +0000 Subject: Add support for reverse cursors --- tests/test_pagination.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index f04079a7..47019671 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -442,6 +442,9 @@ class TestCursorPagination: if item.created > int(created__gt) ] + def order_by(self, ordering): + return self + def __getitem__(self, sliced): return self.items[sliced] @@ -503,6 +506,9 @@ class TestCrazyCursorPagination: if item.created > int(created__gt) ] + def order_by(self, ordering): + return self + def __getitem__(self, sliced): return self.items[sliced] -- cgit v1.2.3 From f1af603fb05fce236a4258e18df8af8888043247 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 10:51:04 +0000 Subject: Tests for reverse pagination --- tests/test_pagination.py | 98 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 27 deletions(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 47019671..4907a080 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -436,13 +436,22 @@ class TestCursorPagination: def __init__(self, items): self.items = items - def filter(self, created__gt): - return [ + def filter(self, created__gt=None, created__lt=None): + if created__gt is not None: + return MockQuerySet([ + item for item in self.items + if item.created > int(created__gt) + ]) + + assert created__lt is not None + return MockQuerySet([ item for item in self.items - if item.created > int(created__gt) - ] + if item.created < int(created__lt) + ]) def order_by(self, ordering): + if ordering.startswith('-'): + return MockQuerySet(reversed(self.items)) return self def __getitem__(self, sliced): @@ -485,6 +494,25 @@ class TestCursorPagination: next_url = self.pagination.get_next_link() assert next_url is None + # Now page back again + + previous_url = self.pagination.get_previous_link() + assert previous_url + + request = Request(factory.get(previous_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [6, 7, 8, 9, 10] + + previous_url = self.pagination.get_previous_link() + assert previous_url + + request = Request(factory.get(previous_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [1, 2, 3, 4, 5] + + previous_url = self.pagination.get_previous_link() + assert previous_url is None + class TestCrazyCursorPagination: """ @@ -500,13 +528,22 @@ class TestCrazyCursorPagination: def __init__(self, items): self.items = items - def filter(self, created__gt): - return [ + def filter(self, created__gt=None, created__lt=None): + if created__gt is not None: + return MockQuerySet([ + item for item in self.items + if item.created > int(created__gt) + ]) + + assert created__lt is not None + return MockQuerySet([ item for item in self.items - if item.created > int(created__gt) - ] + if item.created < int(created__lt) + ]) def order_by(self, ordering): + if ordering.startswith('-'): + return MockQuerySet(reversed(self.items)) return self def __getitem__(self, sliced): @@ -553,25 +590,32 @@ class TestCrazyCursorPagination: next_url = self.pagination.get_next_link() assert next_url is None - # assert content == { - # 'results': [1, 2, 3, 4, 5], - # 'previous': None, - # 'next': 'http://testserver/?limit=5&offset=5', - # 'count': 100 - # } - # assert context == { - # 'previous_url': None, - # 'next_url': 'http://testserver/?limit=5&offset=5', - # 'page_links': [ - # PageLink('http://testserver/?limit=5', 1, True, False), - # PageLink('http://testserver/?limit=5&offset=5', 2, False, False), - # PageLink('http://testserver/?limit=5&offset=10', 3, False, False), - # PAGE_BREAK, - # PageLink('http://testserver/?limit=5&offset=95', 20, False, False), - # ] - # } - # assert self.pagination.display_page_controls - # assert isinstance(self.pagination.to_html(), type('')) + + # Now page back again + + previous_url = self.pagination.get_previous_link() + assert previous_url + + request = Request(factory.get(previous_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [1, 1, 2, 3, 4] + + previous_url = self.pagination.get_previous_link() + assert previous_url + + request = Request(factory.get(previous_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [1, 1, 1, 1, 1] + + previous_url = self.pagination.get_previous_link() + assert previous_url + + request = Request(factory.get(previous_url)) + queryset = self.paginate_queryset(request) + assert [item.created for item in queryset] == [1, 1, 1, 1, 1] + + previous_url = self.pagination.get_previous_link() + assert previous_url is None def test_get_displayed_page_numbers(): -- cgit v1.2.3 From 94b5f7a86e401e46f14fb8982afaa7a8c61847c9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 12:14:52 +0000 Subject: Tidy up cursor tests and make more comprehensive --- tests/test_pagination.py | 212 +++++++++++++++++++---------------------------- 1 file changed, 84 insertions(+), 128 deletions(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 4907a080..e32dd028 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -451,171 +451,127 @@ class TestCursorPagination: def order_by(self, ordering): if ordering.startswith('-'): - return MockQuerySet(reversed(self.items)) + return MockQuerySet(list(reversed(self.items))) return self def __getitem__(self, sliced): return self.items[sliced] - self.pagination = pagination.CursorPagination() - self.queryset = MockQuerySet( - [MockObject(idx) for idx in range(1, 16)] - ) - - def paginate_queryset(self, request): - return list(self.pagination.paginate_queryset(self.queryset, request)) - - # def get_paginated_content(self, queryset): - # response = self.pagination.get_paginated_response(queryset) - # return response.data - - # def get_html_context(self): - # return self.pagination.get_html_context() - - def test_following_cursor(self): - request = Request(factory.get('/')) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [1, 2, 3, 4, 5] - - next_url = self.pagination.get_next_link() - assert next_url + class ExamplePagination(pagination.CursorPagination): + page_size = 5 - request = Request(factory.get(next_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [6, 7, 8, 9, 10] + self.pagination = ExamplePagination() + self.queryset = MockQuerySet([ + MockObject(idx) for idx in [ + 1, 1, 1, 1, 1, + 1, 2, 3, 4, 4, + 4, 4, 5, 6, 7, + 7, 7, 7, 7, 7, + 7, 7, 7, 8, 9, + 9, 9, 9, 9, 9 + ] + ]) - next_url = self.pagination.get_next_link() - assert next_url + def get_pages(self, url): + """ + Given a URL return a tuple of: - request = Request(factory.get(next_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [11, 12, 13, 14, 15] + (previous page, current page, next page, previous url, next url) + """ + request = Request(factory.get(url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + current = [item.created for item in queryset] next_url = self.pagination.get_next_link() - assert next_url is None - - # Now page back again - previous_url = self.pagination.get_previous_link() - assert previous_url - request = Request(factory.get(previous_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [6, 7, 8, 9, 10] + if next_url is not None: + request = Request(factory.get(next_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + next = [item.created for item in queryset] + else: + next = None - previous_url = self.pagination.get_previous_link() - assert previous_url + if previous_url is not None: + request = Request(factory.get(previous_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + previous = [item.created for item in queryset] + else: + previous = None - request = Request(factory.get(previous_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [1, 2, 3, 4, 5] + return (previous, current, next, previous_url, next_url) - previous_url = self.pagination.get_previous_link() - assert previous_url is None - - -class TestCrazyCursorPagination: - """ - Unit tests for `pagination.CursorPagination`. - """ - - def setup(self): - class MockObject(object): - def __init__(self, idx): - self.created = idx - - class MockQuerySet(object): - def __init__(self, items): - self.items = items + def test_invalid_cursor(self): + request = Request(factory.get('/', {'cursor': '123'})) + with pytest.raises(exceptions.NotFound): + self.pagination.paginate_queryset(self.queryset, request) - def filter(self, created__gt=None, created__lt=None): - if created__gt is not None: - return MockQuerySet([ - item for item in self.items - if item.created > int(created__gt) - ]) + def test_cursor_pagination(self): + (previous, current, next, previous_url, next_url) = self.get_pages('/') - assert created__lt is not None - return MockQuerySet([ - item for item in self.items - if item.created < int(created__lt) - ]) + assert previous is None + assert current == [1, 1, 1, 1, 1] + assert next == [1, 2, 3, 4, 4] - def order_by(self, ordering): - if ordering.startswith('-'): - return MockQuerySet(reversed(self.items)) - return self + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) - def __getitem__(self, sliced): - return self.items[sliced] + assert previous == [1, 1, 1, 1, 1] + assert current == [1, 2, 3, 4, 4] + assert next == [4, 4, 5, 6, 7] - self.pagination = pagination.CursorPagination() - self.queryset = MockQuerySet([ - MockObject(idx) for idx in [ - 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, - 1, 1, 2, 3, 4, - 5, 6, 7, 8, 9 - ] - ]) + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) - def paginate_queryset(self, request): - return list(self.pagination.paginate_queryset(self.queryset, request)) + assert previous == [1, 2, 3, 4, 4] + assert current == [4, 4, 5, 6, 7] + assert next == [7, 7, 7, 7, 7] - def test_following_cursor_identical_items(self): - request = Request(factory.get('/')) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [1, 1, 1, 1, 1] + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) - next_url = self.pagination.get_next_link() - assert next_url + assert previous == [4, 4, 4, 5, 6] # Paging artifact + assert current == [7, 7, 7, 7, 7] + assert next == [7, 7, 7, 8, 9] - request = Request(factory.get(next_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [1, 1, 1, 1, 1] + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) - next_url = self.pagination.get_next_link() - assert next_url + assert previous == [7, 7, 7, 7, 7] + assert current == [7, 7, 7, 8, 9] + assert next == [9, 9, 9, 9, 9] - request = Request(factory.get(next_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [1, 1, 2, 3, 4] + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) - next_url = self.pagination.get_next_link() - assert next_url + assert previous == [7, 7, 7, 8, 9] + assert current == [9, 9, 9, 9, 9] + assert next is None - request = Request(factory.get(next_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [5, 6, 7, 8, 9] + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - next_url = self.pagination.get_next_link() - assert next_url is None + assert previous == [7, 7, 7, 7, 7] + assert current == [7, 7, 7, 8, 9] + assert next == [9, 9, 9, 9, 9] - # Now page back again + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - previous_url = self.pagination.get_previous_link() - assert previous_url + assert previous == [4, 4, 5, 6, 7] + assert current == [7, 7, 7, 7, 7] + assert next == [8, 9, 9, 9, 9] # Paging artifact - request = Request(factory.get(previous_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [1, 1, 2, 3, 4] + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - previous_url = self.pagination.get_previous_link() - assert previous_url + assert previous == [1, 2, 3, 4, 4] + assert current == [4, 4, 5, 6, 7] + assert next == [7, 7, 7, 7, 7] - request = Request(factory.get(previous_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [1, 1, 1, 1, 1] + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - previous_url = self.pagination.get_previous_link() - assert previous_url + assert previous == [1, 1, 1, 1, 1] + assert current == [1, 2, 3, 4, 4] + assert next == [4, 4, 5, 6, 7] - request = Request(factory.get(previous_url)) - queryset = self.paginate_queryset(request) - assert [item.created for item in queryset] == [1, 1, 1, 1, 1] + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - previous_url = self.pagination.get_previous_link() - assert previous_url is None + assert previous is None + assert current == [1, 1, 1, 1, 1] + assert next == [1, 2, 3, 4, 4] def test_get_displayed_page_numbers(): -- cgit v1.2.3 From 83a82b44a56a303d43a16dd675fae116e51b9d85 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 15:07:01 +0000 Subject: Support for tuple ordering in cursor pagination --- tests/test_pagination.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index e32dd028..fffdcbe9 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -450,7 +450,7 @@ class TestCursorPagination: ]) def order_by(self, ordering): - if ordering.startswith('-'): + if ordering[0].startswith('-'): return MockQuerySet(list(reversed(self.items))) return self -- cgit v1.2.3 From 408261ee02b176732b7f840f7042e7c24f3ecd27 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 15:15:52 +0000 Subject: Support ordering attribute either on view or on pagination class for CursorPagination --- tests/test_pagination.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index fffdcbe9..c05b4aba 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -459,6 +459,7 @@ class TestCursorPagination: class ExamplePagination(pagination.CursorPagination): page_size = 5 + ordering = 'created' self.pagination = ExamplePagination() self.queryset = MockQuerySet([ -- cgit v1.2.3 From 0822c9e55820f8e4737329e38abc2e21718af9e5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 16:12:05 +0000 Subject: Cursor pagination now works with OrderingFilter --- tests/test_pagination.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index c05b4aba..338be610 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -77,6 +77,20 @@ class TestPaginationIntegration: 'count': 50 } + def test_setting_page_size_to_zero(self): + """ + When page_size parameter is invalid it should return to the default. + """ + request = factory.get('/', {'page_size': 0}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [2, 4, 6, 8, 10], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=0', + 'count': 50 + } + def test_additional_query_params_are_preserved(self): request = factory.get('/', {'page': 2, 'filter': 'even'}) response = self.view(request) @@ -88,6 +102,14 @@ class TestPaginationIntegration: 'count': 50 } + def test_404_not_found_for_zero_page(self): + request = factory.get('/', {'page': '0'}) + response = self.view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == { + 'detail': 'Invalid page "0": That page number is less than 1.' + } + def test_404_not_found_for_invalid_page(self): request = factory.get('/', {'page': 'invalid'}) response = self.view(request) @@ -507,6 +529,24 @@ class TestCursorPagination: with pytest.raises(exceptions.NotFound): self.pagination.paginate_queryset(self.queryset, request) + def test_use_with_ordering_filter(self): + class MockView: + filter_backends = (filters.OrderingFilter,) + ordering_fields = ['username', 'created'] + ordering = 'created' + + request = Request(factory.get('/', {'ordering': 'username'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('username',) + + request = Request(factory.get('/', {'ordering': '-username'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('-username',) + + request = Request(factory.get('/', {'ordering': 'invalid'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('created',) + def test_cursor_pagination(self): (previous, current, next, previous_url, next_url) = self.get_pages('/') -- cgit v1.2.3 From 43d983fae82ab23ca94f52deb29e938eb2a40e88 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 17:25:12 +0000 Subject: Add paging controls --- tests/test_pagination.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 338be610..13bfb627 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,3 +1,4 @@ +# coding: utf-8 from __future__ import unicode_literals from rest_framework import exceptions, generics, pagination, serializers, status, filters from rest_framework.request import Request @@ -471,7 +472,7 @@ class TestCursorPagination: if item.created < int(created__lt) ]) - def order_by(self, ordering): + def order_by(self, *ordering): if ordering[0].startswith('-'): return MockQuerySet(list(reversed(self.items))) return self @@ -614,6 +615,8 @@ class TestCursorPagination: assert current == [1, 1, 1, 1, 1] assert next == [1, 2, 3, 4, 4] + assert isinstance(self.pagination.to_html(), type('')) + def test_get_displayed_page_numbers(): """ -- cgit v1.2.3 From e988d578535fcc820d30dc7c59f1e24f5c911d3c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Jan 2015 11:47:01 +0000 Subject: Fix template loader monkey patching to also support 1.8 --- tests/test_htmlrenderer.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'tests') diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py index 2edc6b4b..a33b832f 100644 --- a/tests/test_htmlrenderer.py +++ b/tests/test_htmlrenderer.py @@ -56,7 +56,13 @@ class TemplateHTMLRendererTests(TestCase): return Template("example: {{ object }}") raise TemplateDoesNotExist(template_name) + def select_template(template_name_list, dirs=None, using=None): + if template_name_list == ['example.html']: + return Template("example: {{ object }}") + raise TemplateDoesNotExist(template_name_list[0]) + django.template.loader.get_template = get_template + django.template.loader.select_template = select_template def tearDown(self): """ -- cgit v1.2.3 From 4cb164b66c0784ce79054925d4744deb5b18d8b2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Jan 2015 11:49:57 +0000 Subject: Add missing skipUnless(django_filters) --- tests/test_filters.py | 1 + 1 file changed, 1 insertion(+) (limited to 'tests') diff --git a/tests/test_filters.py b/tests/test_filters.py index dc84dcbd..5b1b6ca5 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -467,6 +467,7 @@ class DjangoFilterOrderingTests(TestCase): for d in data: DjangoFilterOrderingModel.objects.create(**d) + @unittest.skipUnless(django_filters, 'django-filter not installed') def test_default_ordering(self): class DjangoFilterOrderingView(generics.ListAPIView): serializer_class = DjangoFilterOrderingSerializer -- cgit v1.2.3 From f1ac9d3f9b6c306b7fa48381006d8259c1642a99 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Jan 2015 12:26:44 +0000 Subject: More graceful handling of malformed Content-Disposition --- tests/test_parsers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_parsers.py b/tests/test_parsers.py index d28d8bd4..1d2054ac 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -161,7 +161,9 @@ class TestFileUploadParser(TestCase): self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8--ÀĥƦ.txt') filename = parser.get_filename(self.stream, None, self.parser_context) - self.assertEqual(filename, 'fallback.txt') + # Malformed. Either None or 'fallback.txt' will be acceptable. + # See also https://code.djangoproject.com/ticket/24209 + self.assertIn(filename, ('fallback.txt', None)) def __replace_content_disposition(self, disposition): self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition -- cgit v1.2.3 From f3b6eedb8aeaa23f4b48551356814837973db31c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Jan 2015 12:56:55 +0000 Subject: More sensible response caching. --- tests/test_renderers.py | 85 +++++++------------------------------------------ 1 file changed, 12 insertions(+), 73 deletions(-) (limited to 'tests') diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 00a24fb1..54eea8ce 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -22,7 +22,6 @@ from rest_framework.test import APIRequestFactory from collections import MutableMapping import datetime import json -import pickle import re @@ -618,84 +617,24 @@ class CacheRenderTest(TestCase): urls = 'tests.test_renderers' - cache_key = 'just_a_cache_key' - - @classmethod - def _get_pickling_errors(cls, obj, seen=None): - """ Return any errors that would be raised if `obj' is pickled - Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 - """ - if seen is None: - seen = [] - try: - state = obj.__getstate__() - except AttributeError: - return - if state is None: - return - if isinstance(state, tuple): - if not isinstance(state[0], dict): - state = state[1] - else: - state = state[0].update(state[1]) - result = {} - for i in state: - try: - pickle.dumps(state[i], protocol=2) - except pickle.PicklingError: - if not state[i] in seen: - seen.append(state[i]) - result[i] = cls._get_pickling_errors(state[i], seen) - return result - - def http_resp(self, http_method, url): - """ - Simple wrapper for Client http requests - Removes the `client' and `request' attributes from as they are - added by django.test.client.Client and not part of caching - responses outside of tests. - """ - method = getattr(self.client, http_method) - resp = method(url) - resp._closable_objects = [] - del resp.client, resp.request - try: - del resp.wsgi_request - except AttributeError: - pass - return resp - - def test_obj_pickling(self): - """ - Test that responses are properly pickled - """ - resp = self.http_resp('get', '/cache') - - # Make sure that no pickling errors occurred - self.assertEqual(self._get_pickling_errors(resp), {}) - - # Unfortunately LocMem backend doesn't raise PickleErrors but returns - # None instead. - cache.set(self.cache_key, resp) - self.assertTrue(cache.get(self.cache_key) is not None) - def test_head_caching(self): """ Test caching of HEAD requests """ - resp = self.http_resp('head', '/cache') - cache.set(self.cache_key, resp) - - cached_resp = cache.get(self.cache_key) - self.assertIsInstance(cached_resp, Response) + response = self.client.head('/cache') + cache.set('key', response) + cached_response = cache.get('key') + assert isinstance(cached_response, Response) + assert cached_response.content == response.content + assert cached_response.status_code == response.status_code def test_get_caching(self): """ Test caching of GET requests """ - resp = self.http_resp('get', '/cache') - cache.set(self.cache_key, resp) - - cached_resp = cache.get(self.cache_key) - self.assertIsInstance(cached_resp, Response) - self.assertEqual(cached_resp.content, resp.content) + response = self.client.get('/cache') + cache.set('key', response) + cached_response = cache.get('key') + assert isinstance(cached_response, Response) + assert cached_response.content == response.content + assert cached_response.status_code == response.status_code -- cgit v1.2.3 From e8db1834d3a3f6ba05276b64e5681288aa8f9820 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Jan 2015 15:24:06 +0000 Subject: Added UUIDField. --- tests/test_fields.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) (limited to 'tests') diff --git a/tests/test_fields.py b/tests/test_fields.py index 775d4618..a46cc205 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -4,6 +4,7 @@ from rest_framework import serializers import datetime import django import pytest +import uuid # Tests for field keyword arguments and core functionality. @@ -467,6 +468,23 @@ class TestURLField(FieldValues): field = serializers.URLField() +class TestUUIDField(FieldValues): + """ + Valid and invalid values for `UUIDField`. + """ + valid_inputs = { + '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'), + '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda') + } + invalid_inputs = { + '825d7aeb-05a9-45b5-a5b7': ['"825d7aeb-05a9-45b5-a5b7" is not a valid UUID.'] + } + outputs = { + uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'): '825d7aeb-05a9-45b5-a5b7-05df87923cda' + } + field = serializers.UUIDField() + + # Number types... class TestIntegerField(FieldValues): -- cgit v1.2.3 From 35f6a8246299d31ecce4f791f9527bf34cebe6e2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Jan 2015 16:27:23 +0000 Subject: Added DictField and support for HStoreField. --- tests/test_fields.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_fields.py b/tests/test_fields.py index a46cc205..6744cf64 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1047,7 +1047,7 @@ class TestValidImageField(FieldValues): class TestListField(FieldValues): """ - Values for `ListField`. + Values for `ListField` with IntegerField as child. """ valid_inputs = [ ([1, 2, 3], [1, 2, 3]), @@ -1064,6 +1064,55 @@ class TestListField(FieldValues): field = serializers.ListField(child=serializers.IntegerField()) +class TestUnvalidatedListField(FieldValues): + """ + Values for `ListField` with no `child` argument. + """ + valid_inputs = [ + ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), + ] + invalid_inputs = [ + ('not a list', ['Expected a list of items but got type `str`']), + ] + outputs = [ + ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), + ] + field = serializers.ListField() + + +class TestDictField(FieldValues): + """ + Values for `ListField` with CharField as child. + """ + valid_inputs = [ + ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), + ] + invalid_inputs = [ + ({'a': 1, 'b': None}, ['This field may not be null.']), + ('not a dict', ['Expected a dictionary of items but got type `str`']), + ] + outputs = [ + ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), + ] + field = serializers.DictField(child=serializers.CharField()) + + +class TestUnvalidatedDictField(FieldValues): + """ + Values for `ListField` with no `child` argument. + """ + valid_inputs = [ + ({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}), + ] + invalid_inputs = [ + ('not a dict', ['Expected a dictionary of items but got type `str`']), + ] + outputs = [ + ({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}), + ] + field = serializers.DictField() + + # Tests for FieldField. # --------------------- -- cgit v1.2.3 From b09ef28959fe63351f0dd24564b7d2d344b44fa3 Mon Sep 17 00:00:00 2001 From: Brandon Cazander Date: Sat, 24 Jan 2015 01:37:23 -0800 Subject: Add failing test for request.version AttributeError in BrowsableAPI. --- tests/browsable_api/auth_urls.py | 9 ++++++++- tests/browsable_api/test_browsable_api.py | 10 ++++++++++ tests/browsable_api/views.py | 27 +++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py index bce7dcf9..098a99ac 100644 --- a/tests/browsable_api/auth_urls.py +++ b/tests/browsable_api/auth_urls.py @@ -1,10 +1,17 @@ from __future__ import unicode_literals from django.conf.urls import patterns, url, include +from rest_framework import routers -from .views import MockView +from .views import MockView, FooViewSet, BarViewSet + +router = routers.SimpleRouter() +router.register(r'foo', FooViewSet) +router.register(r'bar', BarViewSet) urlpatterns = patterns( '', (r'^$', MockView.as_view()), + url(r'^', include(router.urls)), + url(r'^bar/(?P\d+)/$', BarViewSet, name='bar-list'), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), ) diff --git a/tests/browsable_api/test_browsable_api.py b/tests/browsable_api/test_browsable_api.py index 5f264783..31907f84 100644 --- a/tests/browsable_api/test_browsable_api.py +++ b/tests/browsable_api/test_browsable_api.py @@ -3,6 +3,7 @@ from django.contrib.auth.models import User from django.test import TestCase from rest_framework.test import APIClient +from .models import Foo, Bar class DropdownWithAuthTests(TestCase): @@ -16,6 +17,8 @@ class DropdownWithAuthTests(TestCase): self.email = 'lennon@thebeatles.com' self.password = 'password' self.user = User.objects.create_user(self.username, self.email, self.password) + foo = Foo.objects.create(name='Foo') + Bar.objects.create(foo=foo) def tearDown(self): self.client.logout() @@ -25,6 +28,13 @@ class DropdownWithAuthTests(TestCase): response = self.client.get('/') self.assertContains(response, 'john') + def test_bug_2455_clone_request(self): + self.client.login(username=self.username, password=self.password) + json_response = self.client.get('/foo/1/?format=json') + self.assertEqual(json_response.status_code, 200) + browsable_api_response = self.client.get('/foo/1/') + self.assertEqual(browsable_api_response.status_code, 200) + def test_logout_shown_when_logged_in(self): self.client.login(username=self.username, password=self.password) response = self.client.get('/') diff --git a/tests/browsable_api/views.py b/tests/browsable_api/views.py index 000f4e80..f06f7c40 100644 --- a/tests/browsable_api/views.py +++ b/tests/browsable_api/views.py @@ -1,9 +1,14 @@ from __future__ import unicode_literals from rest_framework.views import APIView +from rest_framework.viewsets import ModelViewSet from rest_framework import authentication from rest_framework import renderers from rest_framework.response import Response +from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer +from rest_framework.versioning import NamespaceVersioning +from .models import Foo, Bar +from .serializers import FooSerializer, BarSerializer class MockView(APIView): @@ -13,3 +18,25 @@ class MockView(APIView): def get(self, request): return Response({'a': 1, 'b': 2, 'c': 3}) + + +class SerializerClassMixin(object): + def get_serializer_class(self): + # Get base name of serializer + self.request.version + return self.serializer_class + + +class FooViewSet(SerializerClassMixin, ModelViewSet): + versioning_class = NamespaceVersioning + model = Foo + queryset = Foo.objects.all() + serializer_class = FooSerializer + renderer_classes = (BrowsableAPIRenderer, JSONRenderer) + + +class BarViewSet(SerializerClassMixin, ModelViewSet): + model = Bar + queryset = Bar.objects.all() + serializer_class = BarSerializer + renderer_classes = (BrowsableAPIRenderer, ) -- cgit v1.2.3 From 0ee2edc0a14c4d14b8aa6e4b63ccbd0c2cc78024 Mon Sep 17 00:00:00 2001 From: Brandon Cazander Date: Sat, 24 Jan 2015 01:44:09 -0800 Subject: Add missed files for test. --- tests/browsable_api/models.py | 9 +++++++++ tests/browsable_api/serializers.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 tests/browsable_api/models.py create mode 100644 tests/browsable_api/serializers.py (limited to 'tests') diff --git a/tests/browsable_api/models.py b/tests/browsable_api/models.py new file mode 100644 index 00000000..05c6c23b --- /dev/null +++ b/tests/browsable_api/models.py @@ -0,0 +1,9 @@ +from django.db import models + + +class Foo(models.Model): + name = models.CharField(max_length=30) + + +class Bar(models.Model): + foo = models.ForeignKey("Foo", editable=False) diff --git a/tests/browsable_api/serializers.py b/tests/browsable_api/serializers.py new file mode 100644 index 00000000..e8364540 --- /dev/null +++ b/tests/browsable_api/serializers.py @@ -0,0 +1,14 @@ +from .models import Foo, Bar +from rest_framework.serializers import HyperlinkedModelSerializer, HyperlinkedIdentityField + + +class FooSerializer(HyperlinkedModelSerializer): + bar = HyperlinkedIdentityField(view_name='bar-list') + + class Meta: + model = Foo + + +class BarSerializer(HyperlinkedModelSerializer): + class Meta: + model = Bar -- cgit v1.2.3 From 6c083b12a1162bf8e0f51e6c52ff13a1bd621cf2 Mon Sep 17 00:00:00 2001 From: Brandon Cazander Date: Sat, 24 Jan 2015 11:00:36 -0800 Subject: Streamline test for #2455 --- tests/browsable_api/auth_urls.py | 8 +------- tests/browsable_api/models.py | 9 --------- tests/browsable_api/serializers.py | 14 -------------- tests/browsable_api/test_browsable_api.py | 10 ---------- tests/browsable_api/views.py | 27 --------------------------- tests/test_metadata.py | 15 +++++++++++++++ 6 files changed, 16 insertions(+), 67 deletions(-) delete mode 100644 tests/browsable_api/models.py delete mode 100644 tests/browsable_api/serializers.py (limited to 'tests') diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py index 098a99ac..97bc1036 100644 --- a/tests/browsable_api/auth_urls.py +++ b/tests/browsable_api/auth_urls.py @@ -1,17 +1,11 @@ from __future__ import unicode_literals from django.conf.urls import patterns, url, include -from rest_framework import routers -from .views import MockView, FooViewSet, BarViewSet +from .views import MockView -router = routers.SimpleRouter() -router.register(r'foo', FooViewSet) -router.register(r'bar', BarViewSet) urlpatterns = patterns( '', (r'^$', MockView.as_view()), - url(r'^', include(router.urls)), - url(r'^bar/(?P\d+)/$', BarViewSet, name='bar-list'), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), ) diff --git a/tests/browsable_api/models.py b/tests/browsable_api/models.py deleted file mode 100644 index 05c6c23b..00000000 --- a/tests/browsable_api/models.py +++ /dev/null @@ -1,9 +0,0 @@ -from django.db import models - - -class Foo(models.Model): - name = models.CharField(max_length=30) - - -class Bar(models.Model): - foo = models.ForeignKey("Foo", editable=False) diff --git a/tests/browsable_api/serializers.py b/tests/browsable_api/serializers.py deleted file mode 100644 index e8364540..00000000 --- a/tests/browsable_api/serializers.py +++ /dev/null @@ -1,14 +0,0 @@ -from .models import Foo, Bar -from rest_framework.serializers import HyperlinkedModelSerializer, HyperlinkedIdentityField - - -class FooSerializer(HyperlinkedModelSerializer): - bar = HyperlinkedIdentityField(view_name='bar-list') - - class Meta: - model = Foo - - -class BarSerializer(HyperlinkedModelSerializer): - class Meta: - model = Bar diff --git a/tests/browsable_api/test_browsable_api.py b/tests/browsable_api/test_browsable_api.py index 31907f84..5f264783 100644 --- a/tests/browsable_api/test_browsable_api.py +++ b/tests/browsable_api/test_browsable_api.py @@ -3,7 +3,6 @@ from django.contrib.auth.models import User from django.test import TestCase from rest_framework.test import APIClient -from .models import Foo, Bar class DropdownWithAuthTests(TestCase): @@ -17,8 +16,6 @@ class DropdownWithAuthTests(TestCase): self.email = 'lennon@thebeatles.com' self.password = 'password' self.user = User.objects.create_user(self.username, self.email, self.password) - foo = Foo.objects.create(name='Foo') - Bar.objects.create(foo=foo) def tearDown(self): self.client.logout() @@ -28,13 +25,6 @@ class DropdownWithAuthTests(TestCase): response = self.client.get('/') self.assertContains(response, 'john') - def test_bug_2455_clone_request(self): - self.client.login(username=self.username, password=self.password) - json_response = self.client.get('/foo/1/?format=json') - self.assertEqual(json_response.status_code, 200) - browsable_api_response = self.client.get('/foo/1/') - self.assertEqual(browsable_api_response.status_code, 200) - def test_logout_shown_when_logged_in(self): self.client.login(username=self.username, password=self.password) response = self.client.get('/') diff --git a/tests/browsable_api/views.py b/tests/browsable_api/views.py index f06f7c40..000f4e80 100644 --- a/tests/browsable_api/views.py +++ b/tests/browsable_api/views.py @@ -1,14 +1,9 @@ from __future__ import unicode_literals from rest_framework.views import APIView -from rest_framework.viewsets import ModelViewSet from rest_framework import authentication from rest_framework import renderers from rest_framework.response import Response -from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer -from rest_framework.versioning import NamespaceVersioning -from .models import Foo, Bar -from .serializers import FooSerializer, BarSerializer class MockView(APIView): @@ -18,25 +13,3 @@ class MockView(APIView): def get(self, request): return Response({'a': 1, 'b': 2, 'c': 3}) - - -class SerializerClassMixin(object): - def get_serializer_class(self): - # Get base name of serializer - self.request.version - return self.serializer_class - - -class FooViewSet(SerializerClassMixin, ModelViewSet): - versioning_class = NamespaceVersioning - model = Foo - queryset = Foo.objects.all() - serializer_class = FooSerializer - renderer_classes = (BrowsableAPIRenderer, JSONRenderer) - - -class BarViewSet(SerializerClassMixin, ModelViewSet): - model = Bar - queryset = Bar.objects.all() - serializer_class = BarSerializer - renderer_classes = (BrowsableAPIRenderer, ) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 972a896a..bdc84edf 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from rest_framework import exceptions, serializers, status, views from rest_framework.request import Request +from rest_framework.renderers import BrowsableAPIRenderer from rest_framework.test import APIRequestFactory request = Request(APIRequestFactory().options('/')) @@ -168,3 +169,17 @@ class TestMetadata: response = view(request=request) assert response.status_code == status.HTTP_200_OK assert list(response.data['actions'].keys()) == ['POST'] + + def test_bug_2455_clone_request(self): + class ExampleView(views.APIView): + renderer_classes = (BrowsableAPIRenderer,) + + def post(self, request): + pass + + def get_serializer(self): + assert hasattr(self.request, 'version') + return serializers.Serializer() + + view = ExampleView.as_view() + view(request=request) -- cgit v1.2.3 From 65bca59ea548dc5e2222be06ca20b3d3fa151cf0 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 27 Jan 2015 13:51:30 +0000 Subject: Reload api_settings when using Django's 'override_settings' --- tests/test_filters.py | 16 +++++++++++++--- tests/utils.py | 25 ------------------------- 2 files changed, 13 insertions(+), 28 deletions(-) (limited to 'tests') diff --git a/tests/test_filters.py b/tests/test_filters.py index 5b1b6ca5..355f02ce 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -5,13 +5,15 @@ from django.db import models from django.conf.urls import patterns, url from django.core.urlresolvers import reverse from django.test import TestCase +from django.test.utils import override_settings from django.utils import unittest from django.utils.dateparse import parse_date +from django.utils.six.moves import reload_module from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from .models import BaseFilterableItem, FilterableItem, BasicModel -from .utils import temporary_setting + factory = APIRequestFactory() @@ -404,7 +406,9 @@ class SearchFilterTests(TestCase): ) def test_search_with_nonstandard_search_param(self): - with temporary_setting('SEARCH_PARAM', 'query', module=filters): + with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): + reload_module(filters) + class SearchListView(generics.ListAPIView): queryset = SearchFilterModel.objects.all() serializer_class = SearchFilterSerializer @@ -422,6 +426,8 @@ class SearchFilterTests(TestCase): ] ) + reload_module(filters) + class OrderingFilterModel(models.Model): title = models.CharField(max_length=20) @@ -642,7 +648,9 @@ class OrderingFilterTests(TestCase): ) def test_ordering_with_nonstandard_ordering_param(self): - with temporary_setting('ORDERING_PARAM', 'order', filters): + with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): + reload_module(filters) + class OrderingListView(generics.ListAPIView): queryset = OrderingFilterModel.objects.all() serializer_class = OrderingFilterSerializer @@ -662,6 +670,8 @@ class OrderingFilterTests(TestCase): ] ) + reload_module(filters) + class SensitiveOrderingFilterModel(models.Model): username = models.CharField(max_length=20) diff --git a/tests/utils.py b/tests/utils.py index 5e902ba9..5b2d7586 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,30 +1,5 @@ -from contextlib import contextmanager from django.core.exceptions import ObjectDoesNotExist from django.core.urlresolvers import NoReverseMatch -from django.utils import six -from rest_framework.settings import api_settings - - -@contextmanager -def temporary_setting(setting, value, module=None): - """ - Temporarily change value of setting for test. - - Optionally reload given module, useful when module uses value of setting on - import. - """ - original_value = getattr(api_settings, setting) - setattr(api_settings, setting, value) - - if module is not None: - six.moves.reload_module(module) - - yield - - setattr(api_settings, setting, original_value) - - if module is not None: - six.moves.reload_module(module) class MockObject(object): -- cgit v1.2.3 From 8c3f82fb18a58b8e0983612ef3cc35b3c3950b66 Mon Sep 17 00:00:00 2001 From: Susan Dreher Date: Tue, 27 Jan 2015 16:18:51 -0500 Subject: :bug: ManyRelatedField get_value clearing field on partial update A PATCH to a serializer's non-related CharField was clearing an ancillary StringRelatedField(many=True) field. The issue appears to be in the ManyRelatedField's get_value method, which was returning a [] instead of empty when the request data was a MultiDict. This fix mirrors code in fields.py, class Field, get_value, Ln. 272, which explicitly returns empty on a partial update. Tests added to demonstrate the issue. --- tests/test_relations.py | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) (limited to 'tests') diff --git a/tests/test_relations.py b/tests/test_relations.py index 62353dc2..143e835c 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,8 +1,13 @@ -from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset +import pytest + from django.core.exceptions import ImproperlyConfigured +from django.utils.datastructures import MultiValueDict + from rest_framework import serializers +from rest_framework.fields import empty from rest_framework.test import APISimpleTestCase -import pytest + +from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset class TestStringRelatedField(APISimpleTestCase): @@ -134,3 +139,34 @@ class TestSlugRelatedField(APISimpleTestCase): def test_representation(self): representation = self.field.to_representation(self.instance) assert representation == self.instance.name + + +class TestManyRelatedField(APISimpleTestCase): + def setUp(self): + self.instance = MockObject(pk=1, name='foo') + self.field = serializers.StringRelatedField(many=True) + self.field.field_name = 'foo' + + def test_get_value_regular_dictionary_full(self): + assert 'bar' == self.field.get_value({'foo': 'bar'}) + assert empty == self.field.get_value({'baz': 'bar'}) + + def test_get_value_regular_dictionary_partial(self): + setattr(self.field.root, 'partial', True) + assert 'bar' == self.field.get_value({'foo': 'bar'}) + assert empty == self.field.get_value({'baz': 'bar'}) + + def test_get_value_multi_dictionary_full(self): + mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) + assert ['bar1', 'bar2'] == self.field.get_value(mvd) + + mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) + assert [] == self.field.get_value(mvd) + + def test_get_value_multi_dictionary_partial(self): + setattr(self.field.root, 'partial', True) + mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) + assert ['bar1', 'bar2'] == self.field.get_value(mvd) + + mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) + assert empty == self.field.get_value(mvd) -- cgit v1.2.3 From 1714ceae9f468bc1479f0d7a32b0bf26ae9cf15f Mon Sep 17 00:00:00 2001 From: Susan Dreher Date: Tue, 27 Jan 2015 16:31:25 -0500 Subject: reorganize imports --- tests/test_relations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/test_relations.py b/tests/test_relations.py index 143e835c..67f49c6b 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,13 +1,13 @@ -import pytest +from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset from django.core.exceptions import ImproperlyConfigured from django.utils.datastructures import MultiValueDict from rest_framework import serializers from rest_framework.fields import empty -from rest_framework.test import APISimpleTestCase -from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset +from rest_framework.test import APISimpleTestCase +import pytest class TestStringRelatedField(APISimpleTestCase): -- cgit v1.2.3 From e7da266a866adddd5c37453fab33812ee412752b Mon Sep 17 00:00:00 2001 From: Susan Dreher Date: Tue, 27 Jan 2015 16:32:15 -0500 Subject: reorganize imports --- tests/test_relations.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'tests') diff --git a/tests/test_relations.py b/tests/test_relations.py index 67f49c6b..d478d855 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,11 +1,8 @@ from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset - from django.core.exceptions import ImproperlyConfigured from django.utils.datastructures import MultiValueDict - from rest_framework import serializers from rest_framework.fields import empty - from rest_framework.test import APISimpleTestCase import pytest -- cgit v1.2.3 From ac87490b91e3405d497da360afed10842a73dfd0 Mon Sep 17 00:00:00 2001 From: Brandon Cazander Date: Tue, 27 Jan 2015 17:10:17 -0800 Subject: Clone the versioning_scheme when necessary. Fixes #2477 --- tests/test_metadata.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_metadata.py b/tests/test_metadata.py index bdc84edf..5031c0f3 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,5 +1,5 @@ from __future__ import unicode_literals -from rest_framework import exceptions, serializers, status, views +from rest_framework import exceptions, serializers, status, views, versioning from rest_framework.request import Request from rest_framework.renderers import BrowsableAPIRenderer from rest_framework.test import APIRequestFactory @@ -183,3 +183,18 @@ class TestMetadata: view = ExampleView.as_view() view(request=request) + + def test_bug_2477_clone_request(self): + class ExampleView(views.APIView): + renderer_classes = (BrowsableAPIRenderer,) + + def post(self, request): + pass + + def get_serializer(self): + assert hasattr(self.request, 'versioning_scheme') + return serializers.Serializer() + + scheme = versioning.QueryParameterVersioning + view = ExampleView.as_view(versioning_class=scheme) + view(request=request) -- cgit v1.2.3 From ba7dca893cd55a1d5ee928c4b10878c92c44c4f5 Mon Sep 17 00:00:00 2001 From: Tymur Maryokhin Date: Thu, 29 Jan 2015 17:28:03 +0100 Subject: Removed router check for deprecated '.model' attribute --- tests/test_routers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'tests') diff --git a/tests/test_routers.py b/tests/test_routers.py index 86113f5d..948c69bb 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -180,7 +180,7 @@ class TestLookupValueRegex(TestCase): class TestTrailingSlashIncluded(TestCase): def setUp(self): class NoteViewSet(viewsets.ModelViewSet): - model = RouterTestModel + queryset = RouterTestModel.objects.all() self.router = SimpleRouter() self.router.register(r'notes', NoteViewSet) @@ -195,7 +195,7 @@ class TestTrailingSlashIncluded(TestCase): class TestTrailingSlashRemoved(TestCase): def setUp(self): class NoteViewSet(viewsets.ModelViewSet): - model = RouterTestModel + queryset = RouterTestModel.objects.all() self.router = SimpleRouter(trailing_slash=False) self.router.register(r'notes', NoteViewSet) @@ -210,7 +210,8 @@ class TestTrailingSlashRemoved(TestCase): class TestNameableRoot(TestCase): def setUp(self): class NoteViewSet(viewsets.ModelViewSet): - model = RouterTestModel + queryset = RouterTestModel.objects.all() + self.router = DefaultRouter() self.router.root_view_name = 'nameable-root' self.router.register(r'notes', NoteViewSet) -- cgit v1.2.3 From 2cc4cb24652366c6622af08370a0c04b429aa4b8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 31 Jan 2015 08:53:40 +0000 Subject: Fix error text in test. --- tests/test_generics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/test_generics.py b/tests/test_generics.py index fba8718f..88e792ce 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -483,7 +483,7 @@ class TestFilterBackendAppliedToViews(TestCase): request = factory.get('/1') response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - self.assertEqual(response.data, {'detail': 'Not found'}) + self.assertEqual(response.data, {'detail': 'Not found.'}) def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): """ -- cgit v1.2.3 From 77d061d234e03004f34058028707ecddfc730fae Mon Sep 17 00:00:00 2001 From: Brandon Cazander Date: Wed, 28 Jan 2015 17:08:34 -0800 Subject: Provide rest_framework.resolve. Fixes #2489 --- tests/test_relations.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++-- tests/urls.py | 4 ++-- 2 files changed, 50 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/test_relations.py b/tests/test_relations.py index fbe176e2..b82a1f2a 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,11 +1,28 @@ from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset -from django.core.exceptions import ImproperlyConfigured +from django.conf.urls import patterns, url, include +from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist from django.utils.datastructures import MultiValueDict from rest_framework import serializers from rest_framework.fields import empty -from rest_framework.test import APISimpleTestCase +from rest_framework.test import APISimpleTestCase, APIRequestFactory +from rest_framework.versioning import NamespaceVersioning import pytest +factory = APIRequestFactory() +request = factory.get('/') # Just to ensure we have a request in the serializer context + +dummy_view = lambda request, pk: None + +included_patterns = [ + url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') +] + +urlpatterns = patterns( + '', + url(r'^v1/', include(included_patterns, namespace='v1')), + url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') +) + class TestStringRelatedField(APISimpleTestCase): def setUp(self): @@ -48,6 +65,35 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): assert representation == self.instance.pk +class TestHyperlinkedRelatedField(APISimpleTestCase): + urls = 'tests.test_relations' + + def setUp(self): + class HyperlinkedMockQueryset(MockQueryset): + def get(self, **lookup): + for item in self.items: + if item.pk == int(lookup.get('pk', -1)): + return item + raise ObjectDoesNotExist() + + self.queryset = HyperlinkedMockQueryset([ + MockObject(pk=1, name='foo'), + MockObject(pk=2, name='bar'), + MockObject(pk=3, name='baz') + ]) + self.field = serializers.HyperlinkedRelatedField( + view_name='example-detail', + queryset=self.queryset + ) + request = factory.post('/') + request.versioning_scheme = NamespaceVersioning() + self.field._context = {'request': request} + + def test_bug_2489(self): + self.field.to_internal_value('/example/3/') + self.field.to_internal_value('/v1/example/3/') + + class TestHyperlinkedIdentityField(APISimpleTestCase): def setUp(self): self.instance = MockObject(pk=1, name='foo') diff --git a/tests/urls.py b/tests/urls.py index 41f527df..742e361d 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,6 +1,6 @@ """ Blank URLConf just to keep the test suite happy """ -from django.conf.urls import patterns +from tests import test_relations -urlpatterns = patterns('') +urlpatterns = test_relations.urlpatterns -- cgit v1.2.3 From f3067a7fabdd0edb5bc5f48cfdadd2850866c189 Mon Sep 17 00:00:00 2001 From: Brandon Cazander Date: Mon, 2 Feb 2015 20:41:06 -0800 Subject: Remove unnecessary APIRequestFactory get from tests. --- tests/test_relations.py | 1 - 1 file changed, 1 deletion(-) (limited to 'tests') diff --git a/tests/test_relations.py b/tests/test_relations.py index b82a1f2a..ff377d38 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -9,7 +9,6 @@ from rest_framework.versioning import NamespaceVersioning import pytest factory = APIRequestFactory() -request = factory.get('/') # Just to ensure we have a request in the serializer context dummy_view = lambda request, pk: None -- cgit v1.2.3 From 030f01afdbcd4018a288250ef1f4c12de28e63bb Mon Sep 17 00:00:00 2001 From: Brandon Cazander Date: Tue, 3 Feb 2015 02:14:38 -0800 Subject: Reorganize tests. --- tests/test_relations.py | 49 ++---------------------------------------------- tests/test_versioning.py | 41 ++++++++++++++++++++++++++++++++++++++-- tests/urls.py | 4 ++-- 3 files changed, 43 insertions(+), 51 deletions(-) (limited to 'tests') diff --git a/tests/test_relations.py b/tests/test_relations.py index ff377d38..fbe176e2 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,27 +1,11 @@ from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset -from django.conf.urls import patterns, url, include -from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist +from django.core.exceptions import ImproperlyConfigured from django.utils.datastructures import MultiValueDict from rest_framework import serializers from rest_framework.fields import empty -from rest_framework.test import APISimpleTestCase, APIRequestFactory -from rest_framework.versioning import NamespaceVersioning +from rest_framework.test import APISimpleTestCase import pytest -factory = APIRequestFactory() - -dummy_view = lambda request, pk: None - -included_patterns = [ - url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') -] - -urlpatterns = patterns( - '', - url(r'^v1/', include(included_patterns, namespace='v1')), - url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') -) - class TestStringRelatedField(APISimpleTestCase): def setUp(self): @@ -64,35 +48,6 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): assert representation == self.instance.pk -class TestHyperlinkedRelatedField(APISimpleTestCase): - urls = 'tests.test_relations' - - def setUp(self): - class HyperlinkedMockQueryset(MockQueryset): - def get(self, **lookup): - for item in self.items: - if item.pk == int(lookup.get('pk', -1)): - return item - raise ObjectDoesNotExist() - - self.queryset = HyperlinkedMockQueryset([ - MockObject(pk=1, name='foo'), - MockObject(pk=2, name='bar'), - MockObject(pk=3, name='baz') - ]) - self.field = serializers.HyperlinkedRelatedField( - view_name='example-detail', - queryset=self.queryset - ) - request = factory.post('/') - request.versioning_scheme = NamespaceVersioning() - self.field._context = {'request': request} - - def test_bug_2489(self): - self.field.to_internal_value('/example/3/') - self.field.to_internal_value('/v1/example/3/') - - class TestHyperlinkedIdentityField(APISimpleTestCase): def setUp(self): self.instance = MockObject(pk=1, name='foo') diff --git a/tests/test_versioning.py b/tests/test_versioning.py index c44f727d..e7c8485e 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,9 +1,13 @@ +from .utils import MockObject, MockQueryset from django.conf.urls import include, url +from django.core.exceptions import ObjectDoesNotExist +from rest_framework import serializers from rest_framework import status, versioning from rest_framework.decorators import APIView from rest_framework.response import Response from rest_framework.reverse import reverse -from rest_framework.test import APIRequestFactory, APITestCase +from rest_framework.test import APIRequestFactory, APITestCase, APISimpleTestCase +from rest_framework.versioning import NamespaceVersioning class RequestVersionView(APIView): @@ -29,15 +33,18 @@ class RequestInvalidVersionView(APIView): factory = APIRequestFactory() mock_view = lambda request: None +dummy_view = lambda request, pk: None included_patterns = [ url(r'^namespaced/$', mock_view, name='another'), + url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') ] urlpatterns = [ url(r'^v1/', include(included_patterns, namespace='v1')), url(r'^another/$', mock_view, name='another'), - url(r'^(?P[^/]+)/another/$', mock_view, name='another') + url(r'^(?P[^/]+)/another/$', mock_view, name='another'), + url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') ] @@ -221,3 +228,33 @@ class TestInvalidVersion: request.resolver_match = FakeResolverMatch response = view(request, version='v3') assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestHyperlinkedRelatedField(APISimpleTestCase): + urls = 'tests.test_versioning' + + def setUp(self): + + class HyperlinkedMockQueryset(MockQueryset): + def get(self, **lookup): + for item in self.items: + if item.pk == int(lookup.get('pk', -1)): + return item + raise ObjectDoesNotExist() + + self.queryset = HyperlinkedMockQueryset([ + MockObject(pk=1, name='foo'), + MockObject(pk=2, name='bar'), + MockObject(pk=3, name='baz') + ]) + self.field = serializers.HyperlinkedRelatedField( + view_name='example-detail', + queryset=self.queryset + ) + request = factory.post('/', urlconf='tests.test_versioning') + request.versioning_scheme = NamespaceVersioning() + self.field._context = {'request': request} + + def test_bug_2489(self): + self.field.to_internal_value('/example/3/') + self.field.to_internal_value('/v1/example/3/') diff --git a/tests/urls.py b/tests/urls.py index 742e361d..41f527df 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,6 +1,6 @@ """ Blank URLConf just to keep the test suite happy """ -from tests import test_relations +from django.conf.urls import patterns -urlpatterns = test_relations.urlpatterns +urlpatterns = patterns('') -- cgit v1.2.3 From e1c45133126e0c47b8470b4cf7a43c6a7f4fca43 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 5 Feb 2015 00:58:09 +0000 Subject: Fix NamespaceVersioning with hyperlinked serializer fields --- tests/test_relations_hyperlink.py | 7 +++--- tests/test_versioning.py | 50 +++++++++++++++++++++++---------------- tests/utils.py | 24 +++++++++++++++++++ 3 files changed, 56 insertions(+), 25 deletions(-) (limited to 'tests') diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index f1b882ed..aede61d2 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -1,5 +1,5 @@ from __future__ import unicode_literals -from django.conf.urls import patterns, url +from django.conf.urls import url from django.test import TestCase from rest_framework import serializers from rest_framework.test import APIRequestFactory @@ -14,8 +14,7 @@ request = factory.get('/') # Just to ensure we have a request in the serializer dummy_view = lambda request, pk: None -urlpatterns = patterns( - '', +urlpatterns = [ url(r'^dummyurl/(?P[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^manytomanysource/(?P[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanytarget/(?P[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), @@ -24,7 +23,7 @@ urlpatterns = patterns( url(r'^nullableforeignkeysource/(?P[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), url(r'^onetoonetarget/(?P[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), url(r'^nullableonetoonesource/(?P[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), -) +] # ManyToMany diff --git a/tests/test_versioning.py b/tests/test_versioning.py index e7c8485e..cdd10065 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,4 +1,4 @@ -from .utils import MockObject, MockQueryset +from .utils import MockObject, MockQueryset, UsingURLPatterns from django.conf.urls import include, url from django.core.exceptions import ObjectDoesNotExist from rest_framework import serializers @@ -6,8 +6,9 @@ from rest_framework import status, versioning from rest_framework.decorators import APIView from rest_framework.response import Response from rest_framework.reverse import reverse -from rest_framework.test import APIRequestFactory, APITestCase, APISimpleTestCase +from rest_framework.test import APIRequestFactory, APITestCase from rest_framework.versioning import NamespaceVersioning +import pytest class RequestVersionView(APIView): @@ -35,18 +36,6 @@ factory = APIRequestFactory() mock_view = lambda request: None dummy_view = lambda request, pk: None -included_patterns = [ - url(r'^namespaced/$', mock_view, name='another'), - url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') -] - -urlpatterns = [ - url(r'^v1/', include(included_patterns, namespace='v1')), - url(r'^another/$', mock_view, name='another'), - url(r'^(?P[^/]+)/another/$', mock_view, name='another'), - url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') -] - class TestRequestVersion: def test_unversioned(self): @@ -121,8 +110,17 @@ class TestRequestVersion: assert response.data == {'version': None} -class TestURLReversing(APITestCase): - urls = 'tests.test_versioning' +class TestURLReversing(UsingURLPatterns, APITestCase): + included = [ + url(r'^namespaced/$', mock_view, name='another'), + url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') + ] + + urlpatterns = [ + url(r'^v1/', include(included, namespace='v1')), + url(r'^another/$', mock_view, name='another'), + url(r'^(?P[^/]+)/another/$', mock_view, name='another'), + ] def test_reverse_unversioned(self): view = ReverseView.as_view() @@ -230,10 +228,18 @@ class TestInvalidVersion: assert response.status_code == status.HTTP_404_NOT_FOUND -class TestHyperlinkedRelatedField(APISimpleTestCase): - urls = 'tests.test_versioning' +class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase): + included = [ + url(r'^namespaced/(?P\d+)/$', mock_view, name='namespaced'), + ] + + urlpatterns = [ + url(r'^v1/', include(included, namespace='v1')), + url(r'^v2/', include(included, namespace='v2')) + ] def setUp(self): + super(TestHyperlinkedRelatedField, self).setUp() class HyperlinkedMockQueryset(MockQueryset): def get(self, **lookup): @@ -248,13 +254,15 @@ class TestHyperlinkedRelatedField(APISimpleTestCase): MockObject(pk=3, name='baz') ]) self.field = serializers.HyperlinkedRelatedField( - view_name='example-detail', + view_name='namespaced', queryset=self.queryset ) request = factory.post('/', urlconf='tests.test_versioning') request.versioning_scheme = NamespaceVersioning() + request.version = 'v1' self.field._context = {'request': request} def test_bug_2489(self): - self.field.to_internal_value('/example/3/') - self.field.to_internal_value('/v1/example/3/') + self.field.to_internal_value('/v1/namespaced/3/') + with pytest.raises(serializers.ValidationError): + self.field.to_internal_value('/v2/namespaced/3/') diff --git a/tests/utils.py b/tests/utils.py index 5b2d7586..b9034996 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,30 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.urlresolvers import NoReverseMatch +class UsingURLPatterns(object): + """ + Isolates URL patterns used during testing on the test class itself. + For example: + + class MyTestCase(UsingURLPatterns, TestCase): + urlpatterns = [ + ... + ] + + def test_something(self): + ... + """ + urls = __name__ + + def setUp(self): + global urlpatterns + urlpatterns = self.urlpatterns + + def tearDown(self): + global urlpatterns + urlpatterns = [] + + class MockObject(object): def __init__(self, **kwargs): self._kwargs = kwargs -- cgit v1.2.3 From f98f842827c6e79bbaa196482e3c3c549e8999c8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 5 Feb 2015 01:24:55 +0000 Subject: Minor bits of test cleanup --- tests/test_versioning.py | 39 +++++++++++++++------------------------ 1 file changed, 15 insertions(+), 24 deletions(-) (limited to 'tests') diff --git a/tests/test_versioning.py b/tests/test_versioning.py index cdd10065..553463d1 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,6 +1,5 @@ -from .utils import MockObject, MockQueryset, UsingURLPatterns +from .utils import UsingURLPatterns from django.conf.urls import include, url -from django.core.exceptions import ObjectDoesNotExist from rest_framework import serializers from rest_framework import status, versioning from rest_framework.decorators import APIView @@ -33,8 +32,8 @@ class RequestInvalidVersionView(APIView): factory = APIRequestFactory() -mock_view = lambda request: None -dummy_view = lambda request, pk: None +dummy_view = lambda request: None +dummy_pk_view = lambda request, pk: None class TestRequestVersion: @@ -112,14 +111,14 @@ class TestRequestVersion: class TestURLReversing(UsingURLPatterns, APITestCase): included = [ - url(r'^namespaced/$', mock_view, name='another'), - url(r'^example/(?P\d+)/$', dummy_view, name='example-detail') + url(r'^namespaced/$', dummy_view, name='another'), + url(r'^example/(?P\d+)/$', dummy_pk_view, name='example-detail') ] urlpatterns = [ url(r'^v1/', include(included, namespace='v1')), - url(r'^another/$', mock_view, name='another'), - url(r'^(?P[^/]+)/another/$', mock_view, name='another'), + url(r'^another/$', dummy_view, name='another'), + url(r'^(?P[^/]+)/another/$', dummy_view, name='another'), ] def test_reverse_unversioned(self): @@ -230,7 +229,7 @@ class TestInvalidVersion: class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase): included = [ - url(r'^namespaced/(?P\d+)/$', mock_view, name='namespaced'), + url(r'^namespaced/(?P\d+)/$', dummy_view, name='namespaced'), ] urlpatterns = [ @@ -241,28 +240,20 @@ class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase): def setUp(self): super(TestHyperlinkedRelatedField, self).setUp() - class HyperlinkedMockQueryset(MockQueryset): - def get(self, **lookup): - for item in self.items: - if item.pk == int(lookup.get('pk', -1)): - return item - raise ObjectDoesNotExist() - - self.queryset = HyperlinkedMockQueryset([ - MockObject(pk=1, name='foo'), - MockObject(pk=2, name='bar'), - MockObject(pk=3, name='baz') - ]) + class MockQueryset(object): + def get(self, pk): + return 'object %s' % pk + self.field = serializers.HyperlinkedRelatedField( view_name='namespaced', - queryset=self.queryset + queryset=MockQueryset() ) - request = factory.post('/', urlconf='tests.test_versioning') + request = factory.get('/') request.versioning_scheme = NamespaceVersioning() request.version = 'v1' self.field._context = {'request': request} def test_bug_2489(self): - self.field.to_internal_value('/v1/namespaced/3/') + assert self.field.to_internal_value('/v1/namespaced/3/') == 'object 3' with pytest.raises(serializers.ValidationError): self.field.to_internal_value('/v2/namespaced/3/') -- cgit v1.2.3