aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorPhilip Douglas2013-09-10 13:09:25 +0100
committerPhilip Douglas2013-09-10 13:09:25 +0100
commit39e13a0d1341c0a0e694acb1522a99470c4037be (patch)
tree27b498f3cbf81faa1ff587d0730e07706c7551a8 /rest_framework
parentef7ce344865938bea285a408a7cc415a7b90a83c (diff)
parentf5c34926d6a4b4b29fb083d25b99b10d7431eee4 (diff)
downloaddjango-rest-framework-39e13a0d1341c0a0e694acb1522a99470c4037be.tar.bz2
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/fields.py24
-rw-r--r--rest_framework/generics.py31
-rw-r--r--rest_framework/mixins.py15
-rw-r--r--rest_framework/parsers.py10
-rw-r--r--rest_framework/relations.py13
-rw-r--r--rest_framework/renderers.py326
-rw-r--r--rest_framework/request.py23
-rw-r--r--rest_framework/routers.py6
-rw-r--r--rest_framework/serializers.py70
-rw-r--r--rest_framework/settings.py12
-rw-r--r--rest_framework/static/rest_framework/js/default.js45
-rw-r--r--rest_framework/templates/rest_framework/base.html18
-rw-r--r--rest_framework/test.py2
-rw-r--r--rest_framework/tests/test_description.py9
-rw-r--r--rest_framework/tests/test_fields.py17
-rw-r--r--rest_framework/tests/test_files.py37
-rw-r--r--rest_framework/tests/test_generics.py53
-rw-r--r--rest_framework/tests/test_pagination.py47
-rw-r--r--rest_framework/tests/test_relations_nested.py351
-rw-r--r--rest_framework/tests/test_relations_pk.py9
-rw-r--r--rest_framework/tests/test_routers.py2
-rw-r--r--rest_framework/tests/test_testing.py30
-rw-r--r--rest_framework/tests/test_views.py41
-rw-r--r--rest_framework/throttling.py11
-rw-r--r--rest_framework/utils/breadcrumbs.py6
-rw-r--r--rest_framework/utils/formatting.py42
-rw-r--r--rest_framework/views.py122
27 files changed, 1052 insertions, 320 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index add9d224..210c2537 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -16,6 +16,7 @@ from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH
+from django.http import QueryDict
from django.forms import widgets
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
@@ -307,7 +308,10 @@ class WritableField(Field):
try:
if self.use_files:
files = files or {}
- native = files[field_name]
+ try:
+ native = files[field_name]
+ except KeyError:
+ native = data[field_name]
else:
native = data[field_name]
except KeyError:
@@ -399,10 +403,15 @@ class BooleanField(WritableField):
}
empty = False
- # Note: we set default to `False` in order to fill in missing value not
- # supplied by html form. TODO: Fix so that only html form input gets
- # this behavior.
- default = False
+ def field_from_native(self, data, files, field_name, into):
+ # HTML checkboxes do not explicitly represent unchecked as `False`
+ # we deal with that here...
+ if isinstance(data, QueryDict):
+ self.default = False
+
+ return super(BooleanField, self).field_from_native(
+ data, files, field_name, into
+ )
def from_native(self, value):
if value in ('true', 't', 'True', '1'):
@@ -505,6 +514,11 @@ class ChoiceField(WritableField):
return True
return False
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+ return super(ChoiceField, self).from_native(value)
+
class EmailField(CharField):
type_name = 'EmailField'
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 99e9782e..7d1bf794 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -14,6 +14,17 @@ from rest_framework.settings import api_settings
import warnings
+def strict_positive_int(integer_string, cutoff=None):
+ """
+ Cast a string to a strictly positive integer.
+ """
+ ret = int(integer_string)
+ if ret <= 0:
+ raise ValueError()
+ if cutoff:
+ ret = min(ret, cutoff)
+ return ret
+
def get_object_or_404(queryset, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to raise 404
@@ -47,6 +58,7 @@ class GenericAPIView(views.APIView):
# Pagination settings
paginate_by = api_settings.PAGINATE_BY
paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ max_paginate_by = api_settings.MAX_PAGINATE_BY
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
page_kwarg = 'page'
@@ -135,7 +147,7 @@ class GenericAPIView(views.APIView):
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
try:
- page_number = int(page)
+ page_number = strict_positive_int(page)
except ValueError:
if page == 'last':
page_number = paginator.num_pages
@@ -196,9 +208,11 @@ class GenericAPIView(views.APIView):
PendingDeprecationWarning, stacklevel=2)
if self.paginate_by_param:
- query_params = self.request.QUERY_PARAMS
try:
- return int(query_params[self.paginate_by_param])
+ return strict_positive_int(
+ self.request.QUERY_PARAMS[self.paginate_by_param],
+ cutoff=self.max_paginate_by
+ )
except (KeyError, ValueError):
pass
@@ -342,8 +356,15 @@ class GenericAPIView(views.APIView):
self.check_permissions(cloned_request)
# Test object permissions
if method == 'PUT':
- self.get_object()
- except (exceptions.APIException, PermissionDenied, Http404):
+ try:
+ self.get_object()
+ except Http404:
+ # Http404 should be acceptable and the serializer
+ # metadata should be populated. Except this so the
+ # outer "else" clause of the try-except-else block
+ # will be executed.
+ pass
+ except (exceptions.APIException, PermissionDenied):
pass
else:
# If user has appropriate permissions for the view, include
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index f11def6d..426865ff 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -142,11 +142,16 @@ class UpdateModelMixin(object):
try:
return self.get_object()
except Http404:
- # If this is a PUT-as-create operation, we need to ensure that
- # we have relevant permissions, as if this was a POST request.
- # This will either raise a PermissionDenied exception,
- # or simply return None
- self.check_permissions(clone_request(self.request, 'POST'))
+ if self.request.method == 'PUT':
+ # For PUT-as-create operation, we need to ensure that we have
+ # relevant permissions, as if this was a POST request. This
+ # will either raise a PermissionDenied exception, or simply
+ # return None.
+ self.check_permissions(clone_request(self.request, 'POST'))
+ else:
+ # PATCH requests where the object does not exist should still
+ # return a 404 response.
+ raise
def pre_save(self, obj):
"""
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 96bfac84..98fc0341 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -10,9 +10,9 @@ from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
-from rest_framework.compat import yaml, etree
+from rest_framework.compat import etree, six, yaml
from rest_framework.exceptions import ParseError
-from rest_framework.compat import six
+from rest_framework import renderers
import json
import datetime
import decimal
@@ -47,6 +47,7 @@ class JSONParser(BaseParser):
"""
media_type = 'application/json'
+ renderer_class = renderers.UnicodeJSONRenderer
def parse(self, stream, media_type=None, parser_context=None):
"""
@@ -121,7 +122,8 @@ class MultiPartParser(BaseParser):
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
- meta = request.META
+ meta = request.META.copy()
+ meta['CONTENT_TYPE'] = media_type
upload_handlers = request.upload_handlers
try:
@@ -129,7 +131,7 @@ class MultiPartParser(BaseParser):
data, files = parser.parse()
return DataAndFiles(data, files)
except MultiPartParserError as exc:
- raise ParseError('Multipart form parse error - %s' % six.u(exc))
+ raise ParseError('Multipart form parse error - %s' % str(exc))
class XMLParser(BaseParser):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index edaf76d6..35c00bf1 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -134,9 +134,9 @@ class RelatedField(WritableField):
value = obj
for component in source.split('.'):
- value = get_component(value, component)
if value is None:
break
+ value = get_component(value, component)
except ObjectDoesNotExist:
return None
@@ -244,6 +244,8 @@ class PrimaryKeyRelatedField(RelatedField):
source = self.source or field_name
queryset = obj
for component in source.split('.'):
+ if queryset is None:
+ return []
queryset = get_component(queryset, component)
# Forward relationship
@@ -262,7 +264,7 @@ class PrimaryKeyRelatedField(RelatedField):
# RelatedObject (reverse relationship)
try:
pk = getattr(obj, self.source or field_name).pk
- except ObjectDoesNotExist:
+ except (ObjectDoesNotExist, AttributeError):
return None
# Forward relationship
@@ -567,8 +569,13 @@ class HyperlinkedIdentityField(Field):
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf.
"""
- lookup_field = getattr(obj, self.lookup_field)
+ lookup_field = getattr(obj, self.lookup_field, None)
kwargs = {self.lookup_field: lookup_field}
+
+ # Handle unsaved object case
+ if lookup_field is None:
+ return None
+
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 3a03ca33..fca67eee 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -21,11 +21,10 @@ from rest_framework.compat import six
from rest_framework.compat import smart_text
from rest_framework.compat import yaml
from rest_framework.settings import api_settings
-from rest_framework.request import clone_request
+from rest_framework.request import is_form_media_type, override_method
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework.utils.formatting import get_view_name, get_view_description
-from rest_framework import exceptions, parsers, status, VERSION
+from rest_framework import exceptions, status, VERSION
class BaseRenderer(object):
@@ -37,6 +36,7 @@ class BaseRenderer(object):
media_type = None
format = None
charset = 'utf-8'
+ render_style = 'text'
def render(self, data, accepted_media_type=None, renderer_context=None):
raise NotImplemented('Renderer class requires .render() to be implemented')
@@ -52,16 +52,17 @@ class JSONRenderer(BaseRenderer):
format = 'json'
encoder_class = encoders.JSONEncoder
ensure_ascii = True
- charset = 'utf-8'
- # Note that JSON encodings must be utf-8, utf-16 or utf-32.
+ charset = None
+ # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32.
# See: http://www.ietf.org/rfc/rfc4627.txt
+ # Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
Render `data` into JSON.
"""
if data is None:
- return ''
+ return bytes()
# If 'indent' is provided in the context, then pretty print the result.
# E.g. If we're being called by the BrowsableAPIRenderer.
@@ -86,13 +87,12 @@ class JSONRenderer(BaseRenderer):
# and may (or may not) be unicode.
# On python 3.x json.dumps() returns unicode strings.
if isinstance(ret, six.text_type):
- return bytes(ret.encode(self.charset))
+ return bytes(ret.encode('utf-8'))
return ret
class UnicodeJSONRenderer(JSONRenderer):
ensure_ascii = False
- charset = 'utf-8'
"""
Renderer which serializes to JSON.
Does *not* apply JSON's character escaping for non-ascii characters.
@@ -109,6 +109,7 @@ class JSONPRenderer(JSONRenderer):
format = 'jsonp'
callback_parameter = 'callback'
default_callback = 'callback'
+ charset = 'utf-8'
def get_callback(self, renderer_context):
"""
@@ -317,6 +318,90 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):
return data
+class HTMLFormRenderer(BaseRenderer):
+ """
+ Renderers serializer data into an HTML form.
+
+ If the serializer was instantiated without an object then this will
+ return an HTML form not bound to any object,
+ otherwise it will return an HTML form with the appropriate initial data
+ populated from the object.
+
+ Note that rendering of field and form errors is not currently supported.
+ """
+ media_type = 'text/html'
+ format = 'form'
+ template = 'rest_framework/form.html'
+ charset = 'utf-8'
+
+ def data_to_form_fields(self, data):
+ fields = {}
+ for key, val in data.fields.items():
+ if getattr(val, 'read_only', True):
+ # Don't include read-only fields.
+ continue
+
+ if getattr(val, 'fields', None):
+ # Nested data not supported by HTML forms.
+ continue
+
+ kwargs = {}
+ kwargs['required'] = val.required
+
+ #if getattr(v, 'queryset', None):
+ # kwargs['queryset'] = v.queryset
+
+ if getattr(val, 'choices', None) is not None:
+ kwargs['choices'] = val.choices
+
+ if getattr(val, 'regex', None) is not None:
+ kwargs['regex'] = val.regex
+
+ if getattr(val, 'widget', None):
+ widget = copy.deepcopy(val.widget)
+ kwargs['widget'] = widget
+
+ if getattr(val, 'default', None) is not None:
+ kwargs['initial'] = val.default
+
+ if getattr(val, 'label', None) is not None:
+ kwargs['label'] = val.label
+
+ if getattr(val, 'help_text', None) is not None:
+ kwargs['help_text'] = val.help_text
+
+ fields[key] = val.form_field_class(**kwargs)
+
+ return fields
+
+ def render(self, data, accepted_media_type=None, renderer_context=None):
+ """
+ Render serializer data and return an HTML form, as a string.
+ """
+ # The HTMLFormRenderer currently uses something of a hack to render
+ # the content, by translating each of the serializer fields into
+ # an html form field, creating a dynamic form using those fields,
+ # and then rendering that form.
+
+ # This isn't strictly neccessary, as we could render the serilizer
+ # fields to HTML directly. The implementation is historical and will
+ # likely change at some point.
+
+ self.renderer_context = renderer_context or {}
+ request = renderer_context['request']
+
+ # Creating an on the fly form see:
+ # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
+ fields = self.data_to_form_fields(data)
+ DynamicForm = type(str('DynamicForm'), (forms.Form,), fields)
+ data = None if data.empty else data
+
+ template = loader.get_template(self.template)
+ context = RequestContext(request, {'form': DynamicForm(data)})
+
+ return template.render(context)
+
+
class BrowsableAPIRenderer(BaseRenderer):
"""
HTML renderer used to self-document the API.
@@ -325,6 +410,7 @@ class BrowsableAPIRenderer(BaseRenderer):
format = 'api'
template = 'rest_framework/api.html'
charset = 'utf-8'
+ form_renderer_class = HTMLFormRenderer
def get_default_renderer(self, view):
"""
@@ -349,7 +435,10 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer_context['indent'] = 4
content = renderer.render(data, accepted_media_type, renderer_context)
- if renderer.charset is None:
+ render_style = getattr(renderer, 'render_style', 'text')
+ assert render_style in ['text', 'binary'], 'Expected .render_style ' \
+ '"text" or "binary", but got "%s"' % render_style
+ if render_style == 'binary':
return '[%d bytes of binary content]' % len(content)
return content
@@ -372,136 +461,105 @@ class BrowsableAPIRenderer(BaseRenderer):
return False # Doesn't have permissions
return True
- def serializer_to_form_fields(self, serializer):
- fields = {}
- for k, v in serializer.get_fields().items():
- if getattr(v, 'read_only', True):
- continue
-
- kwargs = {}
- kwargs['required'] = v.required
-
- #if getattr(v, 'queryset', None):
- # kwargs['queryset'] = v.queryset
-
- if getattr(v, 'choices', None) is not None:
- kwargs['choices'] = v.choices
-
- if getattr(v, 'regex', None) is not None:
- kwargs['regex'] = v.regex
-
- if getattr(v, 'widget', None):
- widget = copy.deepcopy(v.widget)
- kwargs['widget'] = widget
-
- if getattr(v, 'default', None) is not None:
- kwargs['initial'] = v.default
-
- if getattr(v, 'label', None) is not None:
- kwargs['label'] = v.label
-
- if getattr(v, 'help_text', None) is not None:
- kwargs['help_text'] = v.help_text
-
- fields[k] = v.form_field_class(**kwargs)
-
- return fields
-
- def _get_form(self, view, method, request):
- # We need to impersonate a request with the correct method,
- # so that eg. any dynamic get_serializer_class methods return the
- # correct form for each method.
- restore = view.request
- request = clone_request(request, method)
- view.request = request
- try:
- return self.get_form(view, method, request)
- finally:
- view.request = restore
-
- def _get_raw_data_form(self, view, method, request, media_types):
- # We need to impersonate a request with the correct method,
- # so that eg. any dynamic get_serializer_class methods return the
- # correct form for each method.
- restore = view.request
- request = clone_request(request, method)
- view.request = request
- try:
- return self.get_raw_data_form(view, method, request, media_types)
- finally:
- view.request = restore
-
- def get_form(self, view, method, request):
+ def get_rendered_html_form(self, view, method, request):
"""
- Get a form, possibly bound to either the input or output data.
- In the absence on of the Resource having an associated form then
- provide a form that can be used to submit arbitrary content.
+ Return a string representing a rendered HTML form, possibly bound to
+ either the input or output data.
+
+ In the absence of the View having an associated form then return None.
"""
- obj = getattr(view, 'object', None)
- if not self.show_form_for_method(view, method, request, obj):
- return
+ with override_method(view, request, method) as request:
+ obj = getattr(view, 'object', None)
+ if not self.show_form_for_method(view, method, request, obj):
+ return
- if method in ('DELETE', 'OPTIONS'):
- return True # Don't actually need to return a form
+ if method in ('DELETE', 'OPTIONS'):
+ return True # Don't actually need to return a form
- if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes:
- return
+ if (not getattr(view, 'get_serializer', None)
+ or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)):
+ return
- serializer = view.get_serializer(instance=obj)
- fields = self.serializer_to_form_fields(serializer)
+ serializer = view.get_serializer(instance=obj)
- # Creating an on the fly form see:
- # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
- OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields)
- data = (obj is not None) and serializer.data or None
- form_instance = OnTheFlyForm(data)
- return form_instance
+ data = serializer.data
+ form_renderer = self.form_renderer_class()
+ return form_renderer.render(data, self.accepted_media_type, self.renderer_context)
- def get_raw_data_form(self, view, method, request, media_types):
+ def get_raw_data_form(self, view, method, request):
"""
Returns a form that allows for arbitrary content types to be tunneled
via standard HTML forms.
(Which are typically application/x-www-form-urlencoded)
"""
-
- # If we're not using content overloading there's no point in supplying a generic form,
- # as the view won't treat the form's value as the content of the request.
- if not (api_settings.FORM_CONTENT_OVERRIDE
- and api_settings.FORM_CONTENTTYPE_OVERRIDE):
- return None
-
- # Check permissions
- obj = getattr(view, 'object', None)
- if not self.show_form_for_method(view, method, request, obj):
- return
-
- content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE
- content_field = api_settings.FORM_CONTENT_OVERRIDE
- choices = [(media_type, media_type) for media_type in media_types]
- initial = media_types[0]
-
- # NB. http://jacobian.org/writing/dynamic-form-generation/
- class GenericContentForm(forms.Form):
- def __init__(self):
- super(GenericContentForm, self).__init__()
-
- self.fields[content_type_field] = forms.ChoiceField(
- label='Media type',
- choices=choices,
- initial=initial
- )
- self.fields[content_field] = forms.CharField(
- label='Content',
- widget=forms.Textarea
- )
-
- return GenericContentForm()
+ with override_method(view, request, method) as request:
+ # If we're not using content overloading there's no point in
+ # supplying a generic form, as the view won't treat the form's
+ # value as the content of the request.
+ if not (api_settings.FORM_CONTENT_OVERRIDE
+ and api_settings.FORM_CONTENTTYPE_OVERRIDE):
+ return None
+
+ # Check permissions
+ obj = getattr(view, 'object', None)
+ if not self.show_form_for_method(view, method, request, obj):
+ return
+
+ # If possible, serialize the initial content for the generic form
+ default_parser = view.parser_classes[0]
+ renderer_class = getattr(default_parser, 'renderer_class', None)
+ if (hasattr(view, 'get_serializer') and renderer_class):
+ # View has a serializer defined and parser class has a
+ # corresponding renderer that can be used to render the data.
+
+ # Get a read-only version of the serializer
+ serializer = view.get_serializer(instance=obj)
+ if obj is None:
+ for name, field in serializer.fields.items():
+ if getattr(field, 'read_only', None):
+ del serializer.fields[name]
+
+ # Render the raw data content
+ renderer = renderer_class()
+ accepted = self.accepted_media_type
+ context = self.renderer_context.copy()
+ context['indent'] = 4
+ content = renderer.render(serializer.data, accepted, context)
+ else:
+ content = None
+
+ # Generate a generic form that includes a content type field,
+ # and a content field.
+ content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE
+ content_field = api_settings.FORM_CONTENT_OVERRIDE
+
+ media_types = [parser.media_type for parser in view.parser_classes]
+ choices = [(media_type, media_type) for media_type in media_types]
+ initial = media_types[0]
+
+ # NB. http://jacobian.org/writing/dynamic-form-generation/
+ class GenericContentForm(forms.Form):
+ def __init__(self):
+ super(GenericContentForm, self).__init__()
+
+ self.fields[content_type_field] = forms.ChoiceField(
+ label='Media type',
+ choices=choices,
+ initial=initial
+ )
+ self.fields[content_field] = forms.CharField(
+ label='Content',
+ widget=forms.Textarea,
+ initial=content
+ )
+
+ return GenericContentForm()
def get_name(self, view):
- return get_view_name(view.__class__, getattr(view, 'suffix', None))
+ return view.get_view_name()
def get_description(self, view):
- return get_view_description(view.__class__, html=True)
+ return view.get_view_description(html=True)
def get_breadcrumbs(self, request):
return get_breadcrumbs(request.path)
@@ -510,26 +568,25 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
Render the HTML for the browsable API representation.
"""
- accepted_media_type = accepted_media_type or ''
- renderer_context = renderer_context or {}
+ self.accepted_media_type = accepted_media_type or ''
+ self.renderer_context = renderer_context or {}
view = renderer_context['view']
request = renderer_context['request']
response = renderer_context['response']
- media_types = [parser.media_type for parser in view.parser_classes]
renderer = self.get_default_renderer(view)
content = self.get_content(renderer, data, accepted_media_type, renderer_context)
- put_form = self._get_form(view, 'PUT', request)
- post_form = self._get_form(view, 'POST', request)
- patch_form = self._get_form(view, 'PATCH', request)
- delete_form = self._get_form(view, 'DELETE', request)
- options_form = self._get_form(view, 'OPTIONS', request)
+ put_form = self.get_rendered_html_form(view, 'PUT', request)
+ post_form = self.get_rendered_html_form(view, 'POST', request)
+ patch_form = self.get_rendered_html_form(view, 'PATCH', request)
+ delete_form = self.get_rendered_html_form(view, 'DELETE', request)
+ options_form = self.get_rendered_html_form(view, 'OPTIONS', request)
- raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types)
- raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types)
- raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types)
+ raw_data_put_form = self.get_raw_data_form(view, 'PUT', request)
+ raw_data_post_form = self.get_raw_data_form(view, 'POST', request)
+ raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
name = self.get_name(view)
@@ -582,3 +639,4 @@ class MultiPartRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None):
return encode_multipart(self.BOUNDARY, data)
+
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 919716f4..977d4d96 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -28,6 +28,29 @@ def is_form_media_type(media_type):
base_media_type == 'multipart/form-data')
+class override_method(object):
+ """
+ A context manager that temporarily overrides the method on a request,
+ additionally setting the `view.request` attribute.
+
+ Usage:
+
+ with override_method(view, request, 'POST') as request:
+ ... # Do stuff with `view` and `request`
+ """
+ def __init__(self, view, request, method):
+ self.view = view
+ self.request = request
+ self.method = method
+
+ def __enter__(self):
+ self.view.request = clone_request(self.request, self.method)
+ return self.view.request
+
+ def __exit__(self, *args, **kwarg):
+ self.view.request = self.request
+
+
class Empty(object):
"""
Placeholder for unset attributes.
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 930011d3..3fee1e49 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -189,7 +189,11 @@ class SimpleRouter(BaseRouter):
Given a viewset, return the portion of URL regex that is used
to match against a single instance.
"""
- base_regex = '(?P<{lookup_field}>[^/]+)'
+ if self.trailing_slash:
+ base_regex = '(?P<{lookup_field}>[^/]+)'
+ else:
+ # Don't consume `.json` style suffixes
+ base_regex = '(?P<{lookup_field}>[^/.]+)'
lookup_field = getattr(viewset, 'lookup_field', 'pk')
return base_regex.format(lookup_field=lookup_field)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 31cfa344..a63c7f6c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -32,6 +32,9 @@ from rest_framework.relations import *
from rest_framework.fields import *
+class RelationsList(list):
+ _deleted = []
+
class NestedValidationError(ValidationError):
"""
The default ValidationError behavior is to stringify each item in the list
@@ -161,7 +164,6 @@ class BaseSerializer(WritableField):
self._data = None
self._files = None
self._errors = None
- self._deleted = None
if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True')
@@ -298,7 +300,8 @@ class BaseSerializer(WritableField):
Serialize objects -> primitives.
"""
ret = self._dict_class()
- ret.fields = {}
+ ret.fields = self._dict_class()
+ ret.empty = obj is None
for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name)
@@ -331,14 +334,15 @@ class BaseSerializer(WritableField):
if self.source == '*':
return self.to_native(obj)
+ # Get the raw field value
try:
source = self.source or field_name
value = obj
for component in source.split('.'):
- value = get_component(value, component)
if value is None:
break
+ value = get_component(value, component)
except ObjectDoesNotExist:
return None
@@ -378,6 +382,7 @@ class BaseSerializer(WritableField):
# Set the serializer object if it exists
obj = getattr(self.parent.object, field_name) if self.parent.object else None
+ obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
if self.source == '*':
if value:
@@ -391,7 +396,8 @@ class BaseSerializer(WritableField):
'data': value,
'context': self.context,
'partial': self.partial,
- 'many': self.many
+ 'many': self.many,
+ 'allow_add_remove': self.allow_add_remove
}
serializer = self.__class__(**kwargs)
@@ -434,7 +440,7 @@ class BaseSerializer(WritableField):
DeprecationWarning, stacklevel=3)
if many:
- ret = []
+ ret = RelationsList()
errors = []
update = self.object is not None
@@ -461,8 +467,8 @@ class BaseSerializer(WritableField):
ret.append(self.from_native(item, None))
errors.append(self._errors)
- if update:
- self._deleted = identity_to_objects.values()
+ if update and self.allow_add_remove:
+ ret._deleted = identity_to_objects.values()
self._errors = any(errors) and errors or []
else:
@@ -514,12 +520,12 @@ class BaseSerializer(WritableField):
"""
if isinstance(self.object, list):
[self.save_object(item, **kwargs) for item in self.object]
+
+ if self.object._deleted:
+ [self.delete_object(item) for item in self.object._deleted]
else:
self.save_object(self.object, **kwargs)
- if self.allow_add_remove and self._deleted:
- [self.delete_object(item) for item in self._deleted]
-
return self.object
def metadata(self):
@@ -795,9 +801,12 @@ class ModelSerializer(Serializer):
cls = self.opts.model
opts = get_concrete_model(cls)._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
+
for field_name, field in self.fields.items():
field_name = field.source or field_name
- if field_name in exclusions and not field.read_only:
+ if field_name in exclusions \
+ and not field.read_only \
+ and not isinstance(field, Serializer):
exclusions.remove(field_name)
return exclusions
@@ -823,6 +832,7 @@ class ModelSerializer(Serializer):
"""
m2m_data = {}
related_data = {}
+ nested_forward_relations = {}
meta = self.opts.model._meta
# Reverse fk or one-to-one relations
@@ -842,6 +852,12 @@ class ModelSerializer(Serializer):
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
+ # Nested forward relations - These need to be marked so we can save
+ # them before saving the parent model instance.
+ for field_name in attrs.keys():
+ if isinstance(self.fields.get(field_name, None), Serializer):
+ nested_forward_relations[field_name] = attrs[field_name]
+
# Update an existing instance...
if instance is not None:
for key, val in attrs.items():
@@ -857,6 +873,7 @@ class ModelSerializer(Serializer):
# at the point of save.
instance._related_data = related_data
instance._m2m_data = m2m_data
+ instance._nested_forward_relations = nested_forward_relations
return instance
@@ -872,6 +889,14 @@ class ModelSerializer(Serializer):
"""
Save the deserialized object and return it.
"""
+ if getattr(obj, '_nested_forward_relations', None):
+ # Nested relationships need to be saved before we can save the
+ # parent instance.
+ for field_name, sub_object in obj._nested_forward_relations.items():
+ if sub_object:
+ self.save_object(sub_object)
+ setattr(obj, field_name, sub_object)
+
obj.save(**kwargs)
if getattr(obj, '_m2m_data', None):
@@ -881,7 +906,25 @@ class ModelSerializer(Serializer):
if getattr(obj, '_related_data', None):
for accessor_name, related in obj._related_data.items():
- setattr(obj, accessor_name, related)
+ if isinstance(related, RelationsList):
+ # Nested reverse fk relationship
+ for related_item in related:
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related_item, fk_field, obj)
+ self.save_object(related_item)
+
+ # Delete any removed objects
+ if related._deleted:
+ [self.delete_object(item) for item in related._deleted]
+
+ elif isinstance(related, models.Model):
+ # Nested reverse one-one relationship
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related, fk_field, obj)
+ self.save_object(related)
+ else:
+ # Reverse FK or reverse one-one
+ setattr(obj, accessor_name, related)
del(obj._related_data)
@@ -903,6 +946,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
+ _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
@@ -911,7 +955,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
self.opts.view_name = self._get_default_view_name(self.opts.model)
if 'url' not in fields:
- url_field = HyperlinkedIdentityField(
+ url_field = self._hyperlink_identify_field_class(
view_name=self.opts.view_name,
lookup_field=self.opts.lookup_field
)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 8fd177d5..8abaf140 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -48,7 +48,6 @@ DEFAULTS = {
),
'DEFAULT_THROTTLE_CLASSES': (
),
-
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
@@ -68,11 +67,19 @@ DEFAULTS = {
# Pagination
'PAGINATE_BY': None,
'PAGINATE_BY_PARAM': None,
+ 'MAX_PAGINATE_BY': None,
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # View configuration
+ 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
+ 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
+
+ # Exception handling
+ 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
+
# Testing
'TEST_REQUEST_RENDERER_CLASSES': (
'rest_framework.renderers.MultiPartRenderer',
@@ -121,10 +128,13 @@ IMPORT_STRINGS = (
'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS',
+ 'EXCEPTION_HANDLER',
'FILTER_BACKEND',
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
+ 'VIEW_NAME_FUNCTION',
+ 'VIEW_DESCRIPTION_FUNCTION'
)
diff --git a/rest_framework/static/rest_framework/js/default.js b/rest_framework/static/rest_framework/js/default.js
index c74829d7..bcb1964d 100644
--- a/rest_framework/static/rest_framework/js/default.js
+++ b/rest_framework/static/rest_framework/js/default.js
@@ -1,13 +1,56 @@
+function getCookie(c_name)
+{
+ // From http://www.w3schools.com/js/js_cookies.asp
+ var c_value = document.cookie;
+ var c_start = c_value.indexOf(" " + c_name + "=");
+ if (c_start == -1) {
+ c_start = c_value.indexOf(c_name + "=");
+ }
+ if (c_start == -1) {
+ c_value = null;
+ } else {
+ c_start = c_value.indexOf("=", c_start) + 1;
+ var c_end = c_value.indexOf(";", c_start);
+ if (c_end == -1) {
+ c_end = c_value.length;
+ }
+ c_value = unescape(c_value.substring(c_start,c_end));
+ }
+ return c_value;
+}
+
+// JSON highlighting.
prettyPrint();
+// Bootstrap tooltips.
$('.js-tooltip').tooltip({
delay: 1000
});
+// Deal with rounded tab styling after tab clicks.
$('a[data-toggle="tab"]:first').on('shown', function (e) {
$(e.target).parents('.tabbable').addClass('first-tab-active');
});
$('a[data-toggle="tab"]:not(:first)').on('shown', function (e) {
$(e.target).parents('.tabbable').removeClass('first-tab-active');
});
-$('.form-switcher a:first').tab('show');
+
+$('a[data-toggle="tab"]').click(function(){
+ document.cookie="tabstyle=" + this.name + "; path=/";
+});
+
+// Store tab preference in cookies & display appropriate tab on load.
+var selectedTab = null;
+var selectedTabName = getCookie('tabstyle');
+
+if (selectedTabName) {
+ selectedTab = $('.form-switcher a[name=' + selectedTabName + ']');
+}
+
+if (selectedTab && selectedTab.length > 0) {
+ // Display whichever tab is selected.
+ selectedTab.tab('show');
+} else {
+ // If no tab selected, display rightmost tab.
+ $('.form-switcher a:first').tab('show');
+}
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 51f9c291..aa90e90c 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -128,17 +128,17 @@
<div {% if post_form %}class="tabbable"{% endif %}>
{% if post_form %}
<ul class="nav nav-tabs form-switcher">
- <li><a href="#object-form" data-toggle="tab">HTML form</a></li>
- <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li>
+ <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li>
+ <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li>
</ul>
{% endif %}
<div class="well tab-content">
{% if post_form %}
<div class="tab-pane" id="object-form">
{% with form=post_form %}
- <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
+ <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">
<fieldset>
- {% include "rest_framework/form.html" %}
+ {{ post_form }}
<div class="form-actions">
<button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>
</div>
@@ -167,23 +167,21 @@
<div {% if put_form %}class="tabbable"{% endif %}>
{% if put_form %}
<ul class="nav nav-tabs form-switcher">
- <li><a href="#object-form" data-toggle="tab">HTML form</a></li>
- <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li>
+ <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li>
+ <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li>
</ul>
{% endif %}
<div class="well tab-content">
{% if put_form %}
<div class="tab-pane" id="object-form">
- {% with form=put_form %}
- <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
+ <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">
<fieldset>
- {% include "rest_framework/form.html" %}
+ {{ put_form }}
<div class="form-actions">
<button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
</div>
</fieldset>
</form>
- {% endwith %}
</div>
{% endif %}
<div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form">
diff --git a/rest_framework/test.py b/rest_framework/test.py
index a18f5a29..234d10a4 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient):
"""
self.handler._force_user = user
self.handler._force_token = token
+ if user is None:
+ self.logout() # Also clear any possible session info if required
def request(self, **kwargs):
# Ensure that any credentials set get added to every request.
diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py
index 8019f5ec..4c03c1de 100644
--- a/rest_framework/tests/test_description.py
+++ b/rest_framework/tests/test_description.py
@@ -6,7 +6,6 @@ from rest_framework.compat import apply_markdown, smart_text
from rest_framework.views import APIView
from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring
from rest_framework.tests.description import UTF8_TEST_DOCSTRING
-from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
@@ -58,7 +57,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
class MockView(APIView):
pass
- self.assertEqual(get_view_name(MockView), 'Mock')
+ self.assertEqual(MockView().get_view_name(), 'Mock')
def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring."""
@@ -78,7 +77,7 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEqual(get_view_description(MockView), DESCRIPTION)
+ self.assertEqual(MockView().get_view_description(), DESCRIPTION)
def test_view_description_supports_unicode(self):
"""
@@ -86,7 +85,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
self.assertEqual(
- get_view_description(ViewWithNonASCIICharactersInDocstring),
+ ViewWithNonASCIICharactersInDocstring().get_view_description(),
smart_text(UTF8_TEST_DOCSTRING)
)
@@ -97,7 +96,7 @@ class TestViewNamesAndDescriptions(TestCase):
"""
class MockView(APIView):
pass
- self.assertEqual(get_view_description(MockView), '')
+ self.assertEqual(MockView().get_view_description(), '')
def test_markdown(self):
"""
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index 6836ec86..34fbab9c 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase):
f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES)
+ result = f.from_native('')
+ self.assertEqual(result, None)
+
class EmailFieldTests(TestCase):
"""
@@ -896,3 +904,12 @@ class CustomIntegerField(TestCase):
self.assertFalse(serializer.is_valid())
+class BooleanField(TestCase):
+ """
+ Tests for BooleanField
+ """
+ def test_boolean_required(self):
+ class BooleanRequiredSerializer(serializers.Serializer):
+ bool_field = serializers.BooleanField(required=True)
+
+ self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py
index 487046ac..c13c38b8 100644
--- a/rest_framework/tests/test_files.py
+++ b/rest_framework/tests/test_files.py
@@ -7,13 +7,13 @@ import datetime
class UploadedFile(object):
- def __init__(self, file, created=None):
+ def __init__(self, file=None, created=None):
self.file = file
self.created = created or datetime.datetime.now()
class UploadedFileSerializer(serializers.Serializer):
- file = serializers.FileField()
+ file = serializers.FileField(required=False)
created = serializers.DateTimeField()
def restore_object(self, attrs, instance=None):
@@ -47,5 +47,36 @@ class FileSerializerTests(TestCase):
now = datetime.datetime.now()
serializer = UploadedFileSerializer(data={'created': now})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, now)
+ self.assertIsNone(serializer.object.file)
+
+ def test_remove_with_empty_string(self):
+ """
+ Passing empty string as data should cause file to be removed
+
+ Test for:
+ https://github.com/tomchristie/django-rest-framework/issues/937
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertIsNone(serializer.object.file)
+
+ def test_validation_error_with_non_file(self):
+ """
+ Passing non-files should raise a validation error.
+ """
+ now = datetime.datetime.now()
+ errmsg = 'No file was submitted. Check the encoding type on the form.'
+
+ serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid())
- self.assertIn('file', serializer.errors)
+ self.assertEqual(serializer.errors, {'file': [errmsg]})
diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py
index 1550880b..79cd99ac 100644
--- a/rest_framework/tests/test_generics.py
+++ b/rest_framework/tests/test_generics.py
@@ -272,6 +272,48 @@ class TestInstanceView(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, expected)
+ def test_options_before_instance_create(self):
+ """
+ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
+ before the instance has been created
+ """
+ request = factory.options('/999')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Instance',
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'PUT': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ 'label': 'Text comes here',
+ 'help_text': 'Text description.'
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
+ }
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
+
def test_get_instance_view_incorrect_arg(self):
"""
GET requests with an incorrect pk type, should raise 404, not 500.
@@ -338,6 +380,17 @@ class TestInstanceView(TestCase):
new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar')
+ def test_patch_cannot_create_an_object(self):
+ """
+ PATCH requests should not be able to create objects.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/999', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertFalse(self.objects.filter(id=999).exists())
+
class TestOverriddenGetObject(TestCase):
"""
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
index 85d4640e..4170d4b6 100644
--- a/rest_framework/tests/test_pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView):
paginate_by_param = 'page_size'
+class MaxPaginateByView(generics.ListAPIView):
+ """
+ View for testing custom max_paginate_by usage
+ """
+ model = BasicModel
+ paginate_by = 3
+ max_paginate_by = 5
+ paginate_by_param = 'page_size'
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase):
self.assertEqual(response.data['results'], self.data[:5])
+class TestMaxPaginateByParam(TestCase):
+ """
+ Tests for list views with max_paginate_by kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = MaxPaginateByView.as_view()
+
+ def test_max_paginate_by(self):
+ """
+ If max_paginate_by is set, it should limit page size for the view.
+ """
+ request = factory.get('/?page_size=10')
+ response = self.view(request).render()
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+ def test_max_paginate_by_without_page_size_param(self):
+ """
+ If max_paginate_by is set, but client does not specifiy page_size,
+ standard `paginate_by` behavior should be used.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEqual(response.data['results'], self.data[:3])
+
+
### Tests for context in pagination serializers
class CustomField(serializers.Field):
diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py
index f6d006b3..d393b0c3 100644
--- a/rest_framework/tests/test_relations_nested.py
+++ b/rest_framework/tests/test_relations_nested.py
@@ -1,107 +1,328 @@
from __future__ import unicode_literals
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeySource
- fields = ('id', 'name', 'target')
- depth = 1
+class OneToOneTarget(models.Model):
+ name = models.CharField(max_length=100)
-class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeyTarget
- fields = ('id', 'name', 'sources')
- depth = 1
+class OneToOneSource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, related_name='source',
+ null=True, blank=True)
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
- fields = ('id', 'name', 'target')
- depth = 1
+class OneToManyTarget(models.Model):
+ name = models.CharField(max_length=100)
-class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name', 'nullable_source')
- depth = 1
+class OneToManySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(OneToManyTarget, related_name='sources')
-class ReverseForeignKeyTests(TestCase):
+class ReverseNestedOneToOneTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name')
+
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ source = OneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'source')
+
+ self.Serializer = OneToOneTargetSerializer
+
for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
source.save()
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
]
self.assertEqual(serializer.data, expected)
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1},
- ]},
- {'id': 2, 'name': 'target-2', 'sources': [
- ]}
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
+ {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
]
self.assertEqual(serializer.data, expected)
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
-class NestedNullableForeignKeyTests(TestCase):
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ instance = OneToOneTarget.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class ForwardNestedOneToOneTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name')
+
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ target = OneToOneTargetSerializer()
+
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name', 'target')
+
+ self.Serializer = OneToOneSourceSerializer
+
for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
source.save()
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
+ {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
+
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
]
self.assertEqual(serializer.data, expected)
+ def test_one_to_one_update_to_null(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': None}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
-class NestedNullableOneToOneTests(TestCase):
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+ self.assertEqual(obj.target, None)
+
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # TODO: Nullable 1-1 tests
+ # def test_one_to_one_delete(self):
+ # data = {'id': 3, 'name': 'target-3', 'target_source': None}
+ # instance = OneToOneTarget.objects.get(pk=3)
+ # serializer = self.Serializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # serializer.save()
+
+ # # Ensure (target_source 3, source 3) are deleted,
+ # # and everything else is as expected.
+ # queryset = OneToOneTarget.objects.all()
+ # serializer = self.Serializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ # {'id': 3, 'name': 'target-3', 'source': None}
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class ReverseNestedOneToManyTests(TestCase):
def setUp(self):
- target = OneToOneTarget(name='target-1')
+ class OneToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToManySource
+ fields = ('id', 'name')
+
+ class OneToManyTargetSerializer(serializers.ModelSerializer):
+ sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
+
+ class Meta:
+ model = OneToManyTarget
+ fields = ('id', 'name', 'sources')
+
+ self.Serializer = OneToManyTargetSerializer
+
+ target = OneToManyTarget(name='target-1')
target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
- source.save()
+ for idx in range(1, 4):
+ source = OneToManySource(name='source-%d' % idx, target=target)
+ source.save()
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True)
+ def test_one_to_many_retrieve(self):
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1')
+
+ # Ensure source 4 is added, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}},
- {'id': 2, 'name': 'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create_with_invalid_data(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4}]}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
+
+ def test_one_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1-updated')
+
+ # Ensure (target 1, source 1) are updated,
+ # and everything else is as expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_delete(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+
+ # Ensure source 2 is deleted, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+
]
self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_relations_pk.py b/rest_framework/tests/test_relations_pk.py
index e2a1b815..3815afdd 100644
--- a/rest_framework/tests/test_relations_pk.py
+++ b/rest_framework/tests/test_relations_pk.py
@@ -283,6 +283,15 @@ class PKForeignKeyTests(TestCase):
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+ def test_foreign_key_with_empty(self):
+ """
+ Regression test for #1072
+
+ https://github.com/tomchristie/django-rest-framework/issues/1072
+ """
+ serializer = NullableForeignKeySourceSerializer()
+ self.assertEqual(serializer.data['target'], None)
+
class PKNullableForeignKeyTests(TestCase):
def setUp(self):
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 5fcccb74..e723f7d4 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase):
self.urls = self.router.urls
def test_urls_can_have_trailing_slash_removed(self):
- expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
+ expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 49d45fc2..48b8956b 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -17,8 +17,18 @@ def view(request):
})
+@api_view(['GET', 'POST'])
+def session_view(request):
+ active_session = request.session.get('active_session', False)
+ request.session['active_session'] = True
+ return Response({
+ 'active_session': active_session
+ })
+
+
urlpatterns = patterns('',
url(r'^view/$', view),
+ url(r'^session-view/$', session_view),
)
@@ -46,6 +56,26 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/')
self.assertEqual(response.data['user'], 'example')
+ def test_force_authenticate_with_sessions(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+
+ # First request does not yet have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ # Subsequant requests have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], True)
+
+ # Force authenticating as `None` should also logout the user session.
+ self.client.force_authenticate(None)
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py
index c0bec5ae..65c7e50e 100644
--- a/rest_framework/tests/test_views.py
+++ b/rest_framework/tests/test_views.py
@@ -32,6 +32,16 @@ def basic_view(request):
return {'method': 'PATCH', 'data': request.DATA}
+class ErrorView(APIView):
+ def get(self, request, *args, **kwargs):
+ raise Exception
+
+
+@api_view(['GET'])
+def error_view(request):
+ raise Exception
+
+
def sanitise_json_error(error_dict):
"""
Exact contents of JSON error messages depend on the installed version
@@ -99,3 +109,34 @@ class FunctionBasedViewIntegrationTests(TestCase):
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
+
+
+class TestCustomExceptionHandler(TestCase):
+ def setUp(self):
+ self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
+
+ def exception_handler(exc):
+ return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
+
+ api_settings.EXCEPTION_HANDLER = exception_handler
+
+ def tearDown(self):
+ api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
+
+ def test_class_based_view_exception_handler(self):
+ view = ErrorView.as_view()
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
+
+ def test_function_based_view_exception_handler(self):
+ view = error_view
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 65b45593..a946d837 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -2,7 +2,7 @@
Provides various throttling policies.
"""
from __future__ import unicode_literals
-from django.core.cache import cache
+from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
import time
@@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache.
"""
+ cache = default_cache
timer = time.time
cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None
@@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle):
if self.key is None:
return True
- self.history = cache.get(self.key, [])
+ self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
@@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle):
into the cache.
"""
self.history.insert(0, self.now)
- cache.set(self.key, self.history, self.duration)
+ self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
@@ -151,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
return None # Only throttle unauthenticated requests.
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = request.META.get('HTTP_X_FORWARDED_FOR')
+ if ident is None:
+ ident = request.META.get('REMOTE_ADDR')
return self.cache_format % {
'scope': self.scope,
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index d51374b0..e6690d17 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -1,6 +1,5 @@
from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix
-from rest_framework.utils.formatting import get_view_name
def get_breadcrumbs(url):
@@ -9,8 +8,11 @@ def get_breadcrumbs(url):
tuple of (name, url).
"""
+ from rest_framework.settings import api_settings
from rest_framework.views import APIView
+ view_name_func = api_settings.VIEW_NAME_FUNCTION
+
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
"""
Add tuples of (name, url) to the breadcrumbs list,
@@ -30,7 +32,7 @@ def get_breadcrumbs(url):
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
suffix = getattr(view, 'suffix', None)
- name = get_view_name(view.cls, suffix)
+ name = view_name_func(cls, suffix)
breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view)
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index 4bec8387..4b59ba84 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -5,11 +5,13 @@ from __future__ import unicode_literals
from django.utils.html import escape
from django.utils.safestring import mark_safe
-from rest_framework.compat import apply_markdown, smart_text
+from rest_framework.compat import apply_markdown
+from rest_framework.settings import api_settings
+from textwrap import dedent
import re
-def _remove_trailing_string(content, trailing):
+def remove_trailing_string(content, trailing):
"""
Strip trailing component `trailing` from `content` if it exists.
Used when generating names from view classes.
@@ -19,10 +21,14 @@ def _remove_trailing_string(content, trailing):
return content
-def _remove_leading_indent(content):
+def dedent(content):
"""
Remove leading indent from a block of text.
Used when generating descriptions from docstrings.
+
+ Note that python's `textwrap.dedent` doesn't quite cut it,
+ as it fails to dedent multiline docstrings that include
+ unindented text on the initial line.
"""
whitespace_counts = [len(line) - len(line.lstrip(' '))
for line in content.splitlines()[1:] if line.lstrip()]
@@ -31,11 +37,10 @@ def _remove_leading_indent(content):
if whitespace_counts:
whitespace_pattern = '^' + (' ' * min(whitespace_counts))
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
- content = content.strip('\n')
- return content
+ return content.strip()
-def _camelcase_to_spaces(content):
+def camelcase_to_spaces(content):
"""
Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view classes.
@@ -44,31 +49,6 @@ def _camelcase_to_spaces(content):
content = re.sub(camelcase_boundry, ' \\1', content).strip()
return ' '.join(content.split('_')).title()
-
-def get_view_name(cls, suffix=None):
- """
- Return a formatted name for an `APIView` class or `@api_view` function.
- """
- name = cls.__name__
- name = _remove_trailing_string(name, 'View')
- name = _remove_trailing_string(name, 'ViewSet')
- name = _camelcase_to_spaces(name)
- if suffix:
- name += ' ' + suffix
- return name
-
-
-def get_view_description(cls, html=False):
- """
- Return a description for an `APIView` class or `@api_view` function.
- """
- description = cls.__doc__ or ''
- description = _remove_leading_indent(smart_text(description))
- if html:
- return markup_description(description)
- return description
-
-
def markup_description(description):
"""
Apply HTML markup to the given description.
diff --git a/rest_framework/views.py b/rest_framework/views.py
index d51233a9..853e6461 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -8,16 +8,79 @@ from django.http import Http404
from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
-from rest_framework.compat import View, HttpResponseBase
+from rest_framework.compat import smart_text, HttpResponseBase, View
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
-from rest_framework.utils.formatting import get_view_name, get_view_description
+from rest_framework.utils import formatting
+
+
+def get_view_name(view_cls, suffix=None):
+ """
+ Given a view class, return a textual name to represent the view.
+ This name is used in the browsable API, and in OPTIONS responses.
+
+ This function is the default for the `VIEW_NAME_FUNCTION` setting.
+ """
+ name = view_cls.__name__
+ name = formatting.remove_trailing_string(name, 'View')
+ name = formatting.remove_trailing_string(name, 'ViewSet')
+ name = formatting.camelcase_to_spaces(name)
+ if suffix:
+ name += ' ' + suffix
+
+ return name
+
+def get_view_description(view_cls, html=False):
+ """
+ Given a view class, return a textual description to represent the view.
+ This name is used in the browsable API, and in OPTIONS responses.
+
+ This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting.
+ """
+ description = view_cls.__doc__ or ''
+ description = formatting.dedent(smart_text(description))
+ if html:
+ return formatting.markup_description(description)
+ return description
+
+
+def exception_handler(exc):
+ """
+ Returns the response that should be used for any given exception.
+
+ By default we handle the REST framework `APIException`, and also
+ Django's builtin `Http404` and `PermissionDenied` exceptions.
+
+ Any unhandled exceptions may return `None`, which will cause a 500 error
+ to be raised.
+ """
+ if isinstance(exc, exceptions.APIException):
+ headers = {}
+ if getattr(exc, 'auth_header', None):
+ headers['WWW-Authenticate'] = exc.auth_header
+ if getattr(exc, 'wait', None):
+ headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+
+ return Response({'detail': exc.detail},
+ status=exc.status_code,
+ headers=headers)
+
+ elif isinstance(exc, Http404):
+ return Response({'detail': 'Not found'},
+ status=status.HTTP_404_NOT_FOUND)
+
+ elif isinstance(exc, PermissionDenied):
+ return Response({'detail': 'Permission denied'},
+ status=status.HTTP_403_FORBIDDEN)
+
+ # Note: Unhandled exceptions will raise a 500 error.
+ return None
class APIView(View):
- settings = api_settings
+ # The following policies may be set at either globally, or per-view.
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
@@ -25,6 +88,9 @@ class APIView(View):
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
+ # Allow dependancy injection of other settings to make testing easier.
+ settings = api_settings
+
@classmethod
def as_view(cls, **initkwargs):
"""
@@ -110,6 +176,22 @@ class APIView(View):
'request': getattr(self, 'request', None)
}
+ def get_view_name(self):
+ """
+ Return the view name, as used in OPTIONS responses and in the
+ browsable API.
+ """
+ func = self.settings.VIEW_NAME_FUNCTION
+ return func(self.__class__, getattr(self, 'suffix', None))
+
+ def get_view_description(self, html=False):
+ """
+ Return some descriptive text for the view, as used in OPTIONS responses
+ and in the browsable API.
+ """
+ func = self.settings.VIEW_DESCRIPTION_FUNCTION
+ return func(self.__class__, html)
+
# API policy instantiation methods
def get_format_suffix(self, **kwargs):
@@ -269,33 +351,23 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
- if isinstance(exc, exceptions.Throttled) and exc.wait is not None:
- # Throttle wait header
- self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
-
if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request)
if auth_header:
- self.headers['WWW-Authenticate'] = auth_header
+ exc.auth_header = auth_header
else:
exc.status_code = status.HTTP_403_FORBIDDEN
- if isinstance(exc, exceptions.APIException):
- return Response({'detail': exc.detail},
- status=exc.status_code,
- exception=True)
- elif isinstance(exc, Http404):
- return Response({'detail': 'Not found'},
- status=status.HTTP_404_NOT_FOUND,
- exception=True)
- elif isinstance(exc, PermissionDenied):
- return Response({'detail': 'Permission denied'},
- status=status.HTTP_403_FORBIDDEN,
- exception=True)
- raise
+ response = self.settings.EXCEPTION_HANDLER(exc)
+
+ if response is None:
+ raise
+
+ response.exception = True
+ return response
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.
@@ -342,16 +414,12 @@ class APIView(View):
Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests.
"""
-
- # This is used by ViewSets to disambiguate instance vs list views
- view_name_suffix = getattr(self, 'suffix', None)
-
# By default we can't provide any form-like information, however the
# generic views override this implementation and add additional
# information for POST and PUT methods, based on the serializer.
ret = SortedDict()
- ret['name'] = get_view_name(self.__class__, view_name_suffix)
- ret['description'] = get_view_description(self.__class__)
+ ret['name'] = self.get_view_name()
+ ret['description'] = self.get_view_description()
ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
ret['parses'] = [parser.media_type for parser in self.parser_classes]
return ret