aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2012-10-02 19:54:24 +0100
committerTom Christie2012-10-02 19:54:24 +0100
commite1518de68fa9ad4b2628894abf17924e386ccbda (patch)
treeb1d190f6bd08ce087cf0a2d3e8cd681b5e7e7d04 /rest_framework
parent31b06f1721f98730556dc56927b985e4032788c3 (diff)
parentd1b99f350aded62fe480f7dc4749cd63d52715d2 (diff)
downloaddjango-rest-framework-e1518de68fa9ad4b2628894abf17924e386ccbda.tar.bz2
Merge branch 'restframework2' of https://github.com/tomchristie/django-rest-framework into restframework2
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/fields.py32
-rw-r--r--rest_framework/generics.py55
-rw-r--r--rest_framework/mixins.py21
-rw-r--r--rest_framework/pagination.py80
-rw-r--r--rest_framework/response.py2
-rw-r--r--rest_framework/serializers.py40
-rw-r--r--rest_framework/settings.py8
-rw-r--r--rest_framework/templatetags/rest_framework.py2
-rw-r--r--rest_framework/tests/generics.py39
-rw-r--r--rest_framework/tests/pagination.py87
-rw-r--r--rest_framework/tests/renderers.py19
-rw-r--r--rest_framework/tests/response.py7
12 files changed, 345 insertions, 47 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index eab90617..85ee5430 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -139,7 +139,13 @@ class Field(object):
if hasattr(self, 'model_field'):
return self.to_native(self.model_field._get_val_from_obj(obj))
- return self.to_native(getattr(obj, self.source or field_name))
+ if self.source:
+ value = obj
+ for component in self.source.split('.'):
+ value = getattr(value, component)
+ else:
+ value = getattr(obj, field_name)
+ return self.to_native(value)
def to_native(self, value):
"""
@@ -152,6 +158,8 @@ class Field(object):
return value
elif hasattr(self, 'model_field'):
return self.model_field.value_to_string(self.obj)
+ elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
+ return [self.to_native(item) for item in value]
return smart_unicode(value)
def attributes(self):
@@ -175,7 +183,7 @@ class RelatedField(Field):
"""
def field_to_native(self, obj, field_name):
- obj = getattr(obj, field_name)
+ obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
@@ -215,10 +223,10 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name):
try:
- obj = obj.serializable_value(field_name)
+ obj = obj.serializable_value(self.source or field_name)
except AttributeError:
field = obj._meta.get_field_by_name(field_name)[0]
- obj = getattr(obj, field_name)
+ obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ == 'RelatedManager':
return [self.to_native(item.pk) for item in obj.all()]
elif isinstance(field, RelatedObject):
@@ -431,19 +439,3 @@ class FloatField(Field):
except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
-
-# field_mapping = {
-# models.AutoField: IntegerField,
-# models.BooleanField: BooleanField,
-# models.CharField: CharField,
-# models.DateTimeField: DateTimeField,
-# models.DateField: DateField,
-# models.BigIntegerField: IntegerField,
-# models.IntegerField: IntegerField,
-# models.PositiveIntegerField: IntegerField,
-# models.FloatField: FloatField
-# }
-
-
-# def modelfield_to_serializerfield(field):
-# return field_mapping.get(type(field), Field)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 4240e33e..8647ad42 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -2,7 +2,8 @@
Generic views that provide commmonly needed behaviour.
"""
-from rest_framework import views, mixins, serializers
+from rest_framework import views, mixins
+from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin
from django.views.generic.list import MultipleObjectMixin
@@ -14,22 +15,39 @@ class BaseView(views.APIView):
Base class for all other generic views.
"""
serializer_class = None
+ model_serializer_class = api_settings.MODEL_SERIALIZER
- def get_serializer(self, data=None, files=None, instance=None):
- # TODO: add support for files
- # TODO: add support for seperate serializer/deserializer
+ def get_serializer_context(self):
+ """
+ Extra context provided to the serializer class.
+ """
+ return {
+ 'request': self.request,
+ 'format': self.format,
+ 'view': self
+ }
+
+ def get_serializer_class(self):
+ """
+ Return the class to use for the serializer.
+ Use `self.serializer_class`, falling back to constructing a
+ model serializer class from `self.model_serializer_class`
+ """
serializer_class = self.serializer_class
if serializer_class is None:
- class DefaultSerializer(serializers.ModelSerializer):
+ class DefaultSerializer(self.model_serializer_class):
class Meta:
model = self.model
serializer_class = DefaultSerializer
- context = {
- 'request': self.request,
- 'format': self.kwargs.get('format', None)
- }
+ return serializer_class
+
+ def get_serializer(self, data=None, files=None, instance=None):
+ # TODO: add support for files
+ # TODO: add support for seperate serializer/deserializer
+ serializer_class = self.get_serializer_class()
+ context = self.get_serializer_context()
return serializer_class(data, instance=instance, context=context)
@@ -37,7 +55,24 @@ class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
"""
Base class for generic views onto a queryset.
"""
- pass
+
+ pagination_serializer_class = api_settings.PAGINATION_SERIALIZER
+ paginate_by = api_settings.PAGINATE_BY
+
+ def get_pagination_serializer_class(self):
+ """
+ Return the class to use for the pagination serializer.
+ """
+ class SerializerClass(self.pagination_serializer_class):
+ class Meta:
+ object_serializer_class = self.get_serializer_class()
+
+ return SerializerClass
+
+ def get_pagination_serializer(self, page=None):
+ pagination_serializer_class = self.get_pagination_serializer_class()
+ context = self.get_serializer_context()
+ return pagination_serializer_class(instance=page, context=context)
class SingleObjectBaseView(SingleObjectMixin, BaseView):
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index fe12dc8f..167cd89a 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -7,6 +7,7 @@ which allows mixin classes to be composed in interesting ways.
Eg. Use mixins to build a Resource class, and have a Router class
perform the binding of http methods to actions for us.
"""
+from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
@@ -30,9 +31,27 @@ class ListModelMixin(object):
List a queryset.
Should be mixed in with `MultipleObjectBaseView`.
"""
+ empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
+
def list(self, request, *args, **kwargs):
self.object_list = self.get_queryset()
- serializer = self.get_serializer(instance=self.object_list)
+
+ # Default is to allow empty querysets. This can be altered by setting
+ # `.allow_empty = False`, to raise 404 errors on empty querysets.
+ allow_empty = self.get_allow_empty()
+ if not allow_empty and len(self.object_list) == 0:
+ error_args = {'class_name': self.__class__.__name__}
+ raise Http404(self.empty_error % error_args)
+
+ # Pagination size is set by the `.paginate_by` attribute,
+ # which may be `None` to disable pagination.
+ page_size = self.get_paginate_by(self.object_list)
+ if page_size:
+ paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size)
+ serializer = self.get_pagination_serializer(page)
+ else:
+ serializer = self.get_serializer(instance=self.object_list)
+
return Response(serializer.data)
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
new file mode 100644
index 00000000..131718fd
--- /dev/null
+++ b/rest_framework/pagination.py
@@ -0,0 +1,80 @@
+from rest_framework import serializers
+
+# TODO: Support URLconf kwarg-style paging
+
+
+class NextPageField(serializers.Field):
+ """
+ Field that returns a link to the next page in paginated results.
+ """
+ def to_native(self, value):
+ if not value.has_next():
+ return None
+ page = value.next_page_number()
+ request = self.context.get('request')
+ relative_url = '?page=%d' % page
+ if request:
+ return request.build_absolute_uri(relative_url)
+ return relative_url
+
+
+class PreviousPageField(serializers.Field):
+ """
+ Field that returns a link to the previous page in paginated results.
+ """
+ def to_native(self, value):
+ if not value.has_previous():
+ return None
+ page = value.previous_page_number()
+ request = self.context.get('request')
+ relative_url = '?page=%d' % page
+ if request:
+ return request.build_absolute_uri('?page=%d' % page)
+ return relative_url
+
+
+class PaginationSerializerOptions(serializers.SerializerOptions):
+ """
+ An object that stores the options that may be provided to a
+ pagination serializer by using the inner `Meta` class.
+
+ Accessible on the instance as `serializer.opts`.
+ """
+ def __init__(self, meta):
+ super(PaginationSerializerOptions, self).__init__(meta)
+ self.object_serializer_class = getattr(meta, 'object_serializer_class',
+ serializers.Field)
+
+
+class BasePaginationSerializer(serializers.Serializer):
+ """
+ A base class for pagination serializers to inherit from,
+ to make implementing custom serializers more easy.
+ """
+ _options_class = PaginationSerializerOptions
+ results_field = 'results'
+
+ def __init__(self, *args, **kwargs):
+ """
+ Override init to add in the object serializer field on-the-fly.
+ """
+ super(BasePaginationSerializer, self).__init__(*args, **kwargs)
+ results_field = self.results_field
+ object_serializer = self.opts.object_serializer_class
+ self.fields[results_field] = object_serializer(source='object_list')
+
+ def to_native(self, obj):
+ """
+ Prevent default behaviour of iterating over elements, and serializing
+ each in turn.
+ """
+ return self.convert_object(obj)
+
+
+class PaginationSerializer(BasePaginationSerializer):
+ """
+ A default implementation of a pagination serializer.
+ """
+ count = serializers.Field(source='paginator.count')
+ next = NextPageField(source='*')
+ previous = PreviousPageField(source='*')
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 90516837..db6bf3e2 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -33,6 +33,8 @@ class Response(SimpleTemplateResponse):
@property
def rendered_content(self):
+ assert self.renderer, "No renderer set on Response"
+
self['Content-Type'] = self.renderer.media_type
if self.data is None:
return self.renderer.render()
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 9cbdb9de..a2f211ab 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -3,6 +3,7 @@ import datetime
import types
from decimal import Decimal
from django.core.serializers.base import DeserializedObject
+from django.db import models
from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model
from rest_framework.fields import *
@@ -70,7 +71,7 @@ class SerializerMetaclass(type):
class SerializerOptions(object):
"""
- Meta class options for ModelSerializer
+ Meta class options for Serializer
"""
def __init__(self, meta):
self.nested = getattr(meta, 'nested', False)
@@ -308,17 +309,31 @@ class ModelSerializer(RelatedField, Serializer):
fields += [field for field in opts.many_to_many if field.serialize]
ret = SortedDict()
+ is_pk = True # First field in the list is the pk
+
for model_field in fields:
- if model_field.rel and nested:
+ if is_pk:
+ field = self.get_pk_field(model_field)
+ is_pk = False
+ elif model_field.rel and nested:
field = self.get_nested_field(model_field)
elif model_field.rel:
field = self.get_related_field(model_field)
else:
field = self.get_field(model_field)
- field.initialize(parent=self, model_field=model_field)
- ret[model_field.name] = field
+
+ if field is not None:
+ field.initialize(parent=self, model_field=model_field)
+ ret[model_field.name] = field
+
return ret
+ def get_pk_field(self, model_field):
+ """
+ Returns a default instance of the pk field.
+ """
+ return Field(readonly=True)
+
def get_nested_field(self, model_field):
"""
Creates a default instance of a nested relational field.
@@ -333,9 +348,22 @@ class ModelSerializer(RelatedField, Serializer):
def get_field(self, model_field):
"""
- Creates a default instance of a basic field.
+ Creates a default instance of a basic non-relational field.
"""
- return Field()
+ field_mapping = dict([
+ [models.FloatField, FloatField],
+ [models.IntegerField, IntegerField],
+ [models.DateTimeField, DateTimeField],
+ [models.DateField, DateField],
+ [models.EmailField, EmailField],
+ [models.CharField, CharField],
+ [models.CommaSeparatedIntegerField, CharField],
+ [models.BooleanField, BooleanField]
+ ])
+ try:
+ return field_mapping[model_field.__class__]()
+ except KeyError:
+ return Field()
def restore_object(self, attrs, instance=None):
"""
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index cfc89fe1..2e50e05d 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -44,13 +44,17 @@ DEFAULTS = {
'anon': None,
},
+ 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
+ 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
+ 'PAGINATE_BY': 20,
+
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
'FORM_METHOD_OVERRIDE': '_method',
'FORM_CONTENT_OVERRIDE': '_content',
'FORM_CONTENTTYPE_OVERRIDE': '_content_type',
- 'URL_ACCEPT_OVERRIDE': '_accept',
+ 'URL_ACCEPT_OVERRIDE': 'accept',
'URL_FORMAT_OVERRIDE': 'format',
'FORMAT_SUFFIX_KWARG': 'format'
@@ -65,6 +69,8 @@ IMPORT_STRINGS = (
'DEFAULT_PERMISSIONS',
'DEFAULT_THROTTLES',
'DEFAULT_CONTENT_NEGOTIATION',
+ 'MODEL_SERIALIZER',
+ 'PAGINATION_SERIALIZER',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index 377fd489..c9b6eb10 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -1,5 +1,5 @@
from django import template
-from django.core.urlresolvers import reverse, NoReverseMatch
+from django.core.urlresolvers import reverse
from django.http import QueryDict
from django.utils.encoding import force_unicode
from django.utils.html import escape
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index fee6e3a6..76662373 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -13,6 +13,7 @@ class RootView(generics.RootAPIView):
Example description for OPTIONS.
"""
model = BasicModel
+ paginate_by = None
class InstanceView(generics.InstanceAPIView):
@@ -51,7 +52,8 @@ class TestRootView(TestCase):
POST requests to RootAPIView should create a new object.
"""
content = {'text': 'foobar'}
- request = factory.post('/', json.dumps(content), content_type='application/json')
+ request = factory.post('/', json.dumps(content),
+ content_type='application/json')
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
@@ -63,7 +65,8 @@ class TestRootView(TestCase):
PUT requests to RootAPIView should not be allowed
"""
content = {'text': 'foobar'}
- request = factory.put('/', json.dumps(content), content_type='application/json')
+ request = factory.put('/', json.dumps(content),
+ content_type='application/json')
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEquals(response.data, {"detail": "Method 'PUT' not allowed."})
@@ -99,6 +102,19 @@ class TestRootView(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, expected)
+ def test_post_cannot_set_id(self):
+ """
+ POST requests to create a new object should not be able to set the id.
+ """
+ content = {'id': 999, 'text': 'foobar'}
+ request = factory.post('/', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEquals(created.text, 'foobar')
+
class TestInstanceView(TestCase):
def setUp(self):
@@ -129,7 +145,8 @@ class TestInstanceView(TestCase):
POST requests to InstanceAPIView should not be allowed
"""
content = {'text': 'foobar'}
- request = factory.post('/', json.dumps(content), content_type='application/json')
+ request = factory.post('/', json.dumps(content),
+ content_type='application/json')
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEquals(response.data, {"detail": "Method 'POST' not allowed."})
@@ -139,7 +156,8 @@ class TestInstanceView(TestCase):
PUT requests to InstanceAPIView should update an object.
"""
content = {'text': 'foobar'}
- request = factory.put('/1', json.dumps(content), content_type='application/json')
+ request = factory.put('/1', json.dumps(content),
+ content_type='application/json')
response = self.view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
@@ -178,3 +196,16 @@ class TestInstanceView(TestCase):
}
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, expected)
+
+ def test_put_cannot_set_id(self):
+ """
+ POST requests to create a new object should not be able to set the id.
+ """
+ content = {'id': 999, 'text': 'foobar'}
+ request = factory.put('/1', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEquals(updated.text, 'foobar')
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
new file mode 100644
index 00000000..9e424cc5
--- /dev/null
+++ b/rest_framework/tests/pagination.py
@@ -0,0 +1,87 @@
+from django.core.paginator import Paginator
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import generics, status, pagination
+from rest_framework.tests.models import BasicModel
+
+factory = RequestFactory()
+
+
+class RootView(generics.RootAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+ paginate_by = 10
+
+
+class IntegrationTestPagination(TestCase):
+ """
+ Integration tests for paginated list views.
+ """
+
+ def setUp(self):
+ """
+ Create 26 BasicModel intances.
+ """
+ for char in 'abcdefghijklmnopqrstuvwxyz':
+ BasicModel(text=char * 3).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_paginated_root_view(self):
+ """
+ GET requests to paginated RootAPIView should return paginated results.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[10:20])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[20:])
+ self.assertEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+
+class UnitTestPagination(TestCase):
+ """
+ Unit tests for pagination of primative objects.
+ """
+
+ def setUp(self):
+ self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
+ paginator = Paginator(self.objects, 10)
+ self.first_page = paginator.page(1)
+ self.last_page = paginator.page(3)
+
+ def test_native_pagination(self):
+ serializer = pagination.PaginationSerializer(instance=self.first_page)
+ self.assertEquals(serializer.data['count'], 26)
+ self.assertEquals(serializer.data['next'], '?page=2')
+ self.assertEquals(serializer.data['previous'], None)
+ self.assertEquals(serializer.data['results'], self.objects[:10])
+
+ serializer = pagination.PaginationSerializer(instance=self.last_page)
+ self.assertEquals(serializer.data['count'], 26)
+ self.assertEquals(serializer.data['next'], None)
+ self.assertEquals(serializer.data['previous'], '?page=2')
+ self.assertEquals(serializer.data['results'], self.objects[20:])
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py
index 751f548f..91d84848 100644
--- a/rest_framework/tests/renderers.py
+++ b/rest_framework/tests/renderers.py
@@ -11,6 +11,7 @@ from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
XMLRenderer, JSONPRenderer, DocumentingHTMLRenderer
from rest_framework.parsers import YAMLParser, XMLParser
+from rest_framework.settings import api_settings
from StringIO import StringIO
import datetime
@@ -164,7 +165,11 @@ class RendererEndToEndTests(TestCase):
def test_specified_renderer_serializes_content_on_accept_query(self):
"""The '_accept' query string should behave in the same way as the Accept header."""
- resp = self.client.get('/?_accept=%s' % RendererB.media_type)
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
@@ -177,7 +182,11 @@ class RendererEndToEndTests(TestCase):
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format)
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param)
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
@@ -193,7 +202,11 @@ class RendererEndToEndTests(TestCase):
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format,
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type)
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
index af70a387..f74e54fc 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/response.py
@@ -11,6 +11,7 @@ from rest_framework.renderers import (
JSONRenderer,
DocumentingHTMLRenderer
)
+from rest_framework.settings import api_settings
class MockPickleRenderer(BaseRenderer):
@@ -121,7 +122,11 @@ class RendererIntegrationTests(TestCase):
def test_specified_renderer_serializes_content_on_accept_query(self):
"""The '_accept' query string should behave in the same way as the Accept header."""
- resp = self.client.get('/?_accept=%s' % RendererB.media_type)
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)