aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/fields.py5
-rw-r--r--rest_framework/mixins.py15
-rw-r--r--rest_framework/relations.py11
-rw-r--r--rest_framework/renderers.py192
-rw-r--r--rest_framework/routers.py6
-rw-r--r--rest_framework/serializers.py71
-rw-r--r--rest_framework/settings.py9
-rw-r--r--rest_framework/templates/rest_framework/base.html10
-rw-r--r--rest_framework/test.py2
-rw-r--r--rest_framework/tests/test_files.py37
-rw-r--r--rest_framework/tests/test_generics.py11
-rw-r--r--rest_framework/tests/test_relations_nested.py351
-rw-r--r--rest_framework/tests/test_routers.py2
-rw-r--r--rest_framework/tests/test_testing.py30
-rw-r--r--rest_framework/views.py88
15 files changed, 639 insertions, 201 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 07779c47..3e0ca1a1 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -308,7 +308,10 @@ class WritableField(Field):
try:
if self.use_files:
files = files or {}
- native = files[field_name]
+ try:
+ native = files[field_name]
+ except KeyError:
+ native = data[field_name]
else:
native = data[field_name]
except KeyError:
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index f11def6d..426865ff 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -142,11 +142,16 @@ class UpdateModelMixin(object):
try:
return self.get_object()
except Http404:
- # If this is a PUT-as-create operation, we need to ensure that
- # we have relevant permissions, as if this was a POST request.
- # This will either raise a PermissionDenied exception,
- # or simply return None
- self.check_permissions(clone_request(self.request, 'POST'))
+ if self.request.method == 'PUT':
+ # For PUT-as-create operation, we need to ensure that we have
+ # relevant permissions, as if this was a POST request. This
+ # will either raise a PermissionDenied exception, or simply
+ # return None.
+ self.check_permissions(clone_request(self.request, 'POST'))
+ else:
+ # PATCH requests where the object does not exist should still
+ # return a 404 response.
+ raise
def pre_save(self, obj):
"""
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index edaf76d6..3ad16ee5 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -134,9 +134,9 @@ class RelatedField(WritableField):
value = obj
for component in source.split('.'):
- value = get_component(value, component)
if value is None:
break
+ value = get_component(value, component)
except ObjectDoesNotExist:
return None
@@ -244,6 +244,8 @@ class PrimaryKeyRelatedField(RelatedField):
source = self.source or field_name
queryset = obj
for component in source.split('.'):
+ if queryset is None:
+ return []
queryset = get_component(queryset, component)
# Forward relationship
@@ -567,8 +569,13 @@ class HyperlinkedIdentityField(Field):
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf.
"""
- lookup_field = getattr(obj, self.lookup_field)
+ lookup_field = getattr(obj, self.lookup_field, None)
kwargs = {self.lookup_field: lookup_field}
+
+ # Handle unsaved object case
+ if lookup_field is None:
+ return None
+
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 1006e26c..b30f2ea9 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -36,6 +36,7 @@ class BaseRenderer(object):
media_type = None
format = None
charset = 'utf-8'
+ render_style = 'text'
def render(self, data, accepted_media_type=None, renderer_context=None):
raise NotImplemented('Renderer class requires .render() to be implemented')
@@ -51,16 +52,17 @@ class JSONRenderer(BaseRenderer):
format = 'json'
encoder_class = encoders.JSONEncoder
ensure_ascii = True
- charset = 'utf-8'
- # Note that JSON encodings must be utf-8, utf-16 or utf-32.
+ charset = None
+ # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32.
# See: http://www.ietf.org/rfc/rfc4627.txt
+ # Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
Render `data` into JSON.
"""
if data is None:
- return ''
+ return bytes()
# If 'indent' is provided in the context, then pretty print the result.
# E.g. If we're being called by the BrowsableAPIRenderer.
@@ -85,13 +87,12 @@ class JSONRenderer(BaseRenderer):
# and may (or may not) be unicode.
# On python 3.x json.dumps() returns unicode strings.
if isinstance(ret, six.text_type):
- return bytes(ret.encode(self.charset))
+ return bytes(ret.encode('utf-8'))
return ret
class UnicodeJSONRenderer(JSONRenderer):
ensure_ascii = False
- charset = 'utf-8'
"""
Renderer which serializes to JSON.
Does *not* apply JSON's character escaping for non-ascii characters.
@@ -108,6 +109,7 @@ class JSONPRenderer(JSONRenderer):
format = 'jsonp'
callback_parameter = 'callback'
default_callback = 'callback'
+ charset = 'utf-8'
def get_callback(self, renderer_context):
"""
@@ -316,6 +318,85 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):
return data
+class HTMLFormRenderer(BaseRenderer):
+ """
+ Renderers serializer data into an HTML form.
+
+ If the serializer was instantiated without an object then this will
+ return an HTML form not bound to any object,
+ otherwise it will return an HTML form with the appropriate initial data
+ populated from the object.
+
+ Note that rendering of field and form errors is not currently supported.
+ """
+ media_type = 'text/html'
+ format = 'form'
+ template = 'rest_framework/form.html'
+ charset = 'utf-8'
+
+ def data_to_form_fields(self, data):
+ fields = {}
+ for key, val in data.fields.items():
+ if getattr(val, 'read_only', True):
+ continue
+
+ kwargs = {}
+ kwargs['required'] = val.required
+
+ #if getattr(v, 'queryset', None):
+ # kwargs['queryset'] = v.queryset
+
+ if getattr(val, 'choices', None) is not None:
+ kwargs['choices'] = val.choices
+
+ if getattr(val, 'regex', None) is not None:
+ kwargs['regex'] = val.regex
+
+ if getattr(val, 'widget', None):
+ widget = copy.deepcopy(val.widget)
+ kwargs['widget'] = widget
+
+ if getattr(val, 'default', None) is not None:
+ kwargs['initial'] = val.default
+
+ if getattr(val, 'label', None) is not None:
+ kwargs['label'] = val.label
+
+ if getattr(val, 'help_text', None) is not None:
+ kwargs['help_text'] = val.help_text
+
+ fields[key] = val.form_field_class(**kwargs)
+
+ return fields
+
+ def render(self, data, accepted_media_type=None, renderer_context=None):
+ """
+ Render serializer data and return an HTML form, as a string.
+ """
+ # The HTMLFormRenderer currently uses something of a hack to render
+ # the content, by translating each of the serializer fields into
+ # an html form field, creating a dynamic form using those fields,
+ # and then rendering that form.
+
+ # This isn't strictly neccessary, as we could render the serilizer
+ # fields to HTML directly. The implementation is historical and will
+ # likely change at some point.
+
+ self.renderer_context = renderer_context or {}
+ request = renderer_context['request']
+
+ # Creating an on the fly form see:
+ # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
+ fields = self.data_to_form_fields(data)
+ DynamicForm = type(str('DynamicForm'), (forms.Form,), fields)
+ data = None if data.empty else data
+
+ template = loader.get_template(self.template)
+ context = RequestContext(request, {'form': DynamicForm(data)})
+
+ return template.render(context)
+
+
class BrowsableAPIRenderer(BaseRenderer):
"""
HTML renderer used to self-document the API.
@@ -324,6 +405,7 @@ class BrowsableAPIRenderer(BaseRenderer):
format = 'api'
template = 'rest_framework/api.html'
charset = 'utf-8'
+ form_renderer_class = HTMLFormRenderer
def get_default_renderer(self, view):
"""
@@ -348,7 +430,10 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer_context['indent'] = 4
content = renderer.render(data, accepted_media_type, renderer_context)
- if renderer.charset is None:
+ render_style = getattr(renderer, 'render_style', 'text')
+ assert render_style in ['text', 'binary'], 'Expected .render_style ' \
+ '"text" or "binary", but got "%s"' % render_style
+ if render_style == 'binary':
return '[%d bytes of binary content]' % len(content)
return content
@@ -371,54 +456,7 @@ class BrowsableAPIRenderer(BaseRenderer):
return False # Doesn't have permissions
return True
- def serializer_to_form_fields(self, serializer):
- fields = {}
- for k, v in serializer.get_fields().items():
- if getattr(v, 'read_only', True):
- continue
-
- kwargs = {}
- kwargs['required'] = v.required
-
- #if getattr(v, 'queryset', None):
- # kwargs['queryset'] = v.queryset
-
- if getattr(v, 'choices', None) is not None:
- kwargs['choices'] = v.choices
-
- if getattr(v, 'regex', None) is not None:
- kwargs['regex'] = v.regex
-
- if getattr(v, 'widget', None):
- widget = copy.deepcopy(v.widget)
- kwargs['widget'] = widget
-
- if getattr(v, 'default', None) is not None:
- kwargs['initial'] = v.default
-
- if getattr(v, 'label', None) is not None:
- kwargs['label'] = v.label
-
- if getattr(v, 'help_text', None) is not None:
- kwargs['help_text'] = v.help_text
-
- fields[k] = v.form_field_class(**kwargs)
-
- return fields
-
- def _get_form(self, view, method, request):
- # We need to impersonate a request with the correct method,
- # so that eg. any dynamic get_serializer_class methods return the
- # correct form for each method.
- restore = view.request
- request = clone_request(request, method)
- view.request = request
- try:
- return self.get_form(view, method, request)
- finally:
- view.request = restore
-
- def _get_raw_data_form(self, view, method, request, media_types):
+ def _get_rendered_html_form(self, view, method, request):
# We need to impersonate a request with the correct method,
# so that eg. any dynamic get_serializer_class methods return the
# correct form for each method.
@@ -426,15 +464,16 @@ class BrowsableAPIRenderer(BaseRenderer):
request = clone_request(request, method)
view.request = request
try:
- return self.get_raw_data_form(view, method, request, media_types)
+ return self.get_rendered_html_form(view, method, request)
finally:
view.request = restore
- def get_form(self, view, method, request):
+ def get_rendered_html_form(self, view, method, request):
"""
- Get a form, possibly bound to either the input or output data.
- In the absence on of the Resource having an associated form then
- provide a form that can be used to submit arbitrary content.
+ Return a string representing a rendered HTML form, possibly bound to
+ either the input or output data.
+
+ In the absence of the View having an associated form then return None.
"""
obj = getattr(view, 'object', None)
if not self.show_form_for_method(view, method, request, obj):
@@ -447,14 +486,21 @@ class BrowsableAPIRenderer(BaseRenderer):
return
serializer = view.get_serializer(instance=obj)
- fields = self.serializer_to_form_fields(serializer)
+ data = serializer.data
+ form_renderer = self.form_renderer_class()
+ return form_renderer.render(data, self.accepted_media_type, self.renderer_context)
- # Creating an on the fly form see:
- # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
- OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields)
- data = (obj is not None) and serializer.data or None
- form_instance = OnTheFlyForm(data)
- return form_instance
+ def _get_raw_data_form(self, view, method, request, media_types):
+ # We need to impersonate a request with the correct method,
+ # so that eg. any dynamic get_serializer_class methods return the
+ # correct form for each method.
+ restore = view.request
+ request = clone_request(request, method)
+ view.request = request
+ try:
+ return self.get_raw_data_form(view, method, request, media_types)
+ finally:
+ view.request = restore
def get_raw_data_form(self, view, method, request, media_types):
"""
@@ -509,8 +555,8 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
Render the HTML for the browsable API representation.
"""
- accepted_media_type = accepted_media_type or ''
- renderer_context = renderer_context or {}
+ self.accepted_media_type = accepted_media_type or ''
+ self.renderer_context = renderer_context or {}
view = renderer_context['view']
request = renderer_context['request']
@@ -520,11 +566,11 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer = self.get_default_renderer(view)
content = self.get_content(renderer, data, accepted_media_type, renderer_context)
- put_form = self._get_form(view, 'PUT', request)
- post_form = self._get_form(view, 'POST', request)
- patch_form = self._get_form(view, 'PATCH', request)
- delete_form = self._get_form(view, 'DELETE', request)
- options_form = self._get_form(view, 'OPTIONS', request)
+ put_form = self._get_rendered_html_form(view, 'PUT', request)
+ post_form = self._get_rendered_html_form(view, 'POST', request)
+ patch_form = self._get_rendered_html_form(view, 'PATCH', request)
+ delete_form = self._get_rendered_html_form(view, 'DELETE', request)
+ options_form = self._get_rendered_html_form(view, 'OPTIONS', request)
raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types)
raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types)
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 930011d3..3fee1e49 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -189,7 +189,11 @@ class SimpleRouter(BaseRouter):
Given a viewset, return the portion of URL regex that is used
to match against a single instance.
"""
- base_regex = '(?P<{lookup_field}>[^/]+)'
+ if self.trailing_slash:
+ base_regex = '(?P<{lookup_field}>[^/]+)'
+ else:
+ # Don't consume `.json` style suffixes
+ base_regex = '(?P<{lookup_field}>[^/.]+)'
lookup_field = getattr(viewset, 'lookup_field', 'pk')
return base_regex.format(lookup_field=lookup_field)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 31cfa344..abff6898 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -32,6 +32,9 @@ from rest_framework.relations import *
from rest_framework.fields import *
+class RelationsList(list):
+ _deleted = []
+
class NestedValidationError(ValidationError):
"""
The default ValidationError behavior is to stringify each item in the list
@@ -161,7 +164,6 @@ class BaseSerializer(WritableField):
self._data = None
self._files = None
self._errors = None
- self._deleted = None
if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True')
@@ -298,7 +300,8 @@ class BaseSerializer(WritableField):
Serialize objects -> primitives.
"""
ret = self._dict_class()
- ret.fields = {}
+ ret.fields = self._dict_class()
+ ret.empty = obj is None
for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name)
@@ -336,9 +339,9 @@ class BaseSerializer(WritableField):
value = obj
for component in source.split('.'):
- value = get_component(value, component)
if value is None:
- break
+ return self.to_native(None)
+ value = get_component(value, component)
except ObjectDoesNotExist:
return None
@@ -378,6 +381,7 @@ class BaseSerializer(WritableField):
# Set the serializer object if it exists
obj = getattr(self.parent.object, field_name) if self.parent.object else None
+ obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
if self.source == '*':
if value:
@@ -391,7 +395,8 @@ class BaseSerializer(WritableField):
'data': value,
'context': self.context,
'partial': self.partial,
- 'many': self.many
+ 'many': self.many,
+ 'allow_add_remove': self.allow_add_remove
}
serializer = self.__class__(**kwargs)
@@ -434,7 +439,7 @@ class BaseSerializer(WritableField):
DeprecationWarning, stacklevel=3)
if many:
- ret = []
+ ret = RelationsList()
errors = []
update = self.object is not None
@@ -461,8 +466,8 @@ class BaseSerializer(WritableField):
ret.append(self.from_native(item, None))
errors.append(self._errors)
- if update:
- self._deleted = identity_to_objects.values()
+ if update and self.allow_add_remove:
+ ret._deleted = identity_to_objects.values()
self._errors = any(errors) and errors or []
else:
@@ -514,12 +519,12 @@ class BaseSerializer(WritableField):
"""
if isinstance(self.object, list):
[self.save_object(item, **kwargs) for item in self.object]
+
+ if self.object._deleted:
+ [self.delete_object(item) for item in self.object._deleted]
else:
self.save_object(self.object, **kwargs)
- if self.allow_add_remove and self._deleted:
- [self.delete_object(item) for item in self._deleted]
-
return self.object
def metadata(self):
@@ -795,9 +800,12 @@ class ModelSerializer(Serializer):
cls = self.opts.model
opts = get_concrete_model(cls)._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
+
for field_name, field in self.fields.items():
field_name = field.source or field_name
- if field_name in exclusions and not field.read_only:
+ if field_name in exclusions \
+ and not field.read_only \
+ and not isinstance(field, Serializer):
exclusions.remove(field_name)
return exclusions
@@ -823,6 +831,7 @@ class ModelSerializer(Serializer):
"""
m2m_data = {}
related_data = {}
+ nested_forward_relations = {}
meta = self.opts.model._meta
# Reverse fk or one-to-one relations
@@ -842,6 +851,12 @@ class ModelSerializer(Serializer):
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
+ # Nested forward relations - These need to be marked so we can save
+ # them before saving the parent model instance.
+ for field_name in attrs.keys():
+ if isinstance(self.fields.get(field_name, None), Serializer):
+ nested_forward_relations[field_name] = attrs[field_name]
+
# Update an existing instance...
if instance is not None:
for key, val in attrs.items():
@@ -857,6 +872,7 @@ class ModelSerializer(Serializer):
# at the point of save.
instance._related_data = related_data
instance._m2m_data = m2m_data
+ instance._nested_forward_relations = nested_forward_relations
return instance
@@ -872,6 +888,14 @@ class ModelSerializer(Serializer):
"""
Save the deserialized object and return it.
"""
+ if getattr(obj, '_nested_forward_relations', None):
+ # Nested relationships need to be saved before we can save the
+ # parent instance.
+ for field_name, sub_object in obj._nested_forward_relations.items():
+ if sub_object:
+ self.save_object(sub_object)
+ setattr(obj, field_name, sub_object)
+
obj.save(**kwargs)
if getattr(obj, '_m2m_data', None):
@@ -881,7 +905,25 @@ class ModelSerializer(Serializer):
if getattr(obj, '_related_data', None):
for accessor_name, related in obj._related_data.items():
- setattr(obj, accessor_name, related)
+ if isinstance(related, RelationsList):
+ # Nested reverse fk relationship
+ for related_item in related:
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related_item, fk_field, obj)
+ self.save_object(related_item)
+
+ # Delete any removed objects
+ if related._deleted:
+ [self.delete_object(item) for item in related._deleted]
+
+ elif isinstance(related, models.Model):
+ # Nested reverse one-one relationship
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related, fk_field, obj)
+ self.save_object(related)
+ else:
+ # Reverse FK or reverse one-one
+ setattr(obj, accessor_name, related)
del(obj._related_data)
@@ -903,6 +945,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
+ _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
@@ -911,7 +954,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
self.opts.view_name = self._get_default_view_name(self.opts.model)
if 'url' not in fields:
- url_field = HyperlinkedIdentityField(
+ url_field = self._hyperlink_identify_field_class(
view_name=self.opts.view_name,
lookup_field=self.opts.lookup_field
)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index b8e40bfa..8c084751 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -48,7 +48,6 @@ DEFAULTS = {
),
'DEFAULT_THROTTLE_CLASSES': (
),
-
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
@@ -70,14 +69,14 @@ DEFAULTS = {
'PAGINATE_BY_PARAM': None,
'MAX_PAGINATE_BY': None,
- # View configuration
- 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
- 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
-
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # View configuration
+ 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
+ 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
+
# Testing
'TEST_REQUEST_RENDERER_CLASSES': (
'rest_framework.renderers.MultiPartRenderer',
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 51f9c291..6ae47563 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -136,9 +136,9 @@
{% if post_form %}
<div class="tab-pane" id="object-form">
{% with form=post_form %}
- <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
+ <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">
<fieldset>
- {% include "rest_framework/form.html" %}
+ {{ post_form }}
<div class="form-actions">
<button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>
</div>
@@ -174,16 +174,14 @@
<div class="well tab-content">
{% if put_form %}
<div class="tab-pane" id="object-form">
- {% with form=put_form %}
- <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
+ <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">
<fieldset>
- {% include "rest_framework/form.html" %}
+ {{ put_form }}
<div class="form-actions">
<button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
</div>
</fieldset>
</form>
- {% endwith %}
</div>
{% endif %}
<div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form">
diff --git a/rest_framework/test.py b/rest_framework/test.py
index a18f5a29..234d10a4 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient):
"""
self.handler._force_user = user
self.handler._force_token = token
+ if user is None:
+ self.logout() # Also clear any possible session info if required
def request(self, **kwargs):
# Ensure that any credentials set get added to every request.
diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py
index 487046ac..c13c38b8 100644
--- a/rest_framework/tests/test_files.py
+++ b/rest_framework/tests/test_files.py
@@ -7,13 +7,13 @@ import datetime
class UploadedFile(object):
- def __init__(self, file, created=None):
+ def __init__(self, file=None, created=None):
self.file = file
self.created = created or datetime.datetime.now()
class UploadedFileSerializer(serializers.Serializer):
- file = serializers.FileField()
+ file = serializers.FileField(required=False)
created = serializers.DateTimeField()
def restore_object(self, attrs, instance=None):
@@ -47,5 +47,36 @@ class FileSerializerTests(TestCase):
now = datetime.datetime.now()
serializer = UploadedFileSerializer(data={'created': now})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, now)
+ self.assertIsNone(serializer.object.file)
+
+ def test_remove_with_empty_string(self):
+ """
+ Passing empty string as data should cause file to be removed
+
+ Test for:
+ https://github.com/tomchristie/django-rest-framework/issues/937
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertIsNone(serializer.object.file)
+
+ def test_validation_error_with_non_file(self):
+ """
+ Passing non-files should raise a validation error.
+ """
+ now = datetime.datetime.now()
+ errmsg = 'No file was submitted. Check the encoding type on the form.'
+
+ serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid())
- self.assertIn('file', serializer.errors)
+ self.assertEqual(serializer.errors, {'file': [errmsg]})
diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py
index 1550880b..7a87d389 100644
--- a/rest_framework/tests/test_generics.py
+++ b/rest_framework/tests/test_generics.py
@@ -338,6 +338,17 @@ class TestInstanceView(TestCase):
new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar')
+ def test_patch_cannot_create_an_object(self):
+ """
+ PATCH requests should not be able to create objects.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/999', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertFalse(self.objects.filter(id=999).exists())
+
class TestOverriddenGetObject(TestCase):
"""
diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py
index f6d006b3..d393b0c3 100644
--- a/rest_framework/tests/test_relations_nested.py
+++ b/rest_framework/tests/test_relations_nested.py
@@ -1,107 +1,328 @@
from __future__ import unicode_literals
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeySource
- fields = ('id', 'name', 'target')
- depth = 1
+class OneToOneTarget(models.Model):
+ name = models.CharField(max_length=100)
-class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeyTarget
- fields = ('id', 'name', 'sources')
- depth = 1
+class OneToOneSource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, related_name='source',
+ null=True, blank=True)
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
- fields = ('id', 'name', 'target')
- depth = 1
+class OneToManyTarget(models.Model):
+ name = models.CharField(max_length=100)
-class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name', 'nullable_source')
- depth = 1
+class OneToManySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(OneToManyTarget, related_name='sources')
-class ReverseForeignKeyTests(TestCase):
+class ReverseNestedOneToOneTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name')
+
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ source = OneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'source')
+
+ self.Serializer = OneToOneTargetSerializer
+
for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
source.save()
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
]
self.assertEqual(serializer.data, expected)
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1},
- ]},
- {'id': 2, 'name': 'target-2', 'sources': [
- ]}
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
+ {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
]
self.assertEqual(serializer.data, expected)
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
-class NestedNullableForeignKeyTests(TestCase):
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ instance = OneToOneTarget.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class ForwardNestedOneToOneTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name')
+
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ target = OneToOneTargetSerializer()
+
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name', 'target')
+
+ self.Serializer = OneToOneSourceSerializer
+
for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
source.save()
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
+ {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
+
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
]
self.assertEqual(serializer.data, expected)
+ def test_one_to_one_update_to_null(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': None}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
-class NestedNullableOneToOneTests(TestCase):
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+ self.assertEqual(obj.target, None)
+
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # TODO: Nullable 1-1 tests
+ # def test_one_to_one_delete(self):
+ # data = {'id': 3, 'name': 'target-3', 'target_source': None}
+ # instance = OneToOneTarget.objects.get(pk=3)
+ # serializer = self.Serializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # serializer.save()
+
+ # # Ensure (target_source 3, source 3) are deleted,
+ # # and everything else is as expected.
+ # queryset = OneToOneTarget.objects.all()
+ # serializer = self.Serializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ # {'id': 3, 'name': 'target-3', 'source': None}
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class ReverseNestedOneToManyTests(TestCase):
def setUp(self):
- target = OneToOneTarget(name='target-1')
+ class OneToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToManySource
+ fields = ('id', 'name')
+
+ class OneToManyTargetSerializer(serializers.ModelSerializer):
+ sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
+
+ class Meta:
+ model = OneToManyTarget
+ fields = ('id', 'name', 'sources')
+
+ self.Serializer = OneToManyTargetSerializer
+
+ target = OneToManyTarget(name='target-1')
target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
- source.save()
+ for idx in range(1, 4):
+ source = OneToManySource(name='source-%d' % idx, target=target)
+ source.save()
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True)
+ def test_one_to_many_retrieve(self):
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1')
+
+ # Ensure source 4 is added, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}},
- {'id': 2, 'name': 'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create_with_invalid_data(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4}]}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
+
+ def test_one_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1-updated')
+
+ # Ensure (target 1, source 1) are updated,
+ # and everything else is as expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_delete(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+
+ # Ensure source 2 is deleted, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+
]
self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 5fcccb74..e723f7d4 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase):
self.urls = self.router.urls
def test_urls_can_have_trailing_slash_removed(self):
- expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
+ expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 49d45fc2..48b8956b 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -17,8 +17,18 @@ def view(request):
})
+@api_view(['GET', 'POST'])
+def session_view(request):
+ active_session = request.session.get('active_session', False)
+ request.session['active_session'] = True
+ return Response({
+ 'active_session': active_session
+ })
+
+
urlpatterns = patterns('',
url(r'^view/$', view),
+ url(r'^session-view/$', session_view),
)
@@ -46,6 +56,26 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/')
self.assertEqual(response.data['user'], 'example')
+ def test_force_authenticate_with_sessions(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+
+ # First request does not yet have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ # Subsequant requests have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], True)
+
+ # Force authenticating as `None` should also logout the user session.
+ self.client.force_authenticate(None)
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 727a9f95..4cff0422 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -15,8 +15,14 @@ from rest_framework.settings import api_settings
from rest_framework.utils import formatting
-def get_view_name(cls, suffix=None):
- name = cls.__name__
+def get_view_name(view_cls, suffix=None):
+ """
+ Given a view class, return a textual name to represent the view.
+ This name is used in the browsable API, and in OPTIONS responses.
+
+ This function is the default for the `VIEW_NAME_FUNCTION` setting.
+ """
+ name = view_cls.__name__
name = formatting.remove_trailing_string(name, 'View')
name = formatting.remove_trailing_string(name, 'ViewSet')
name = formatting.camelcase_to_spaces(name)
@@ -25,17 +31,56 @@ def get_view_name(cls, suffix=None):
return name
-def get_view_description(cls, html=False):
- description = cls.__doc__ or ''
+def get_view_description(view_cls, html=False):
+ """
+ Given a view class, return a textual description to represent the view.
+ This name is used in the browsable API, and in OPTIONS responses.
+
+ This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting.
+ """
+ description = view_cls.__doc__ or ''
description = formatting.dedent(smart_text(description))
if html:
return formatting.markup_description(description)
return description
+def exception_handler(exc):
+ """
+ Returns the response that should be used for any given exception.
+
+ By default we handle the REST framework `APIException`, and also
+ Django's builtin `Http404` and `PermissionDenied` exceptions.
+
+ Any unhandled exceptions may return `None`, which will cause a 500 error
+ to be raised.
+ """
+ if isinstance(exc, exceptions.APIException):
+ headers = {}
+ if getattr(exc, 'auth_header', None):
+ headers['WWW-Authenticate'] = exc.auth_header
+ if getattr(exc, 'wait', None):
+ headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+
+ return Response({'detail': exc.detail},
+ status=exc.status_code,
+ headers=headers)
+
+ elif isinstance(exc, Http404):
+ return Response({'detail': 'Not found'},
+ status=status.HTTP_404_NOT_FOUND)
+
+ elif isinstance(exc, PermissionDenied):
+ return Response({'detail': 'Permission denied'},
+ status=status.HTTP_403_FORBIDDEN)
+
+ # Note: Unhandled exceptions will raise a 500 error.
+ return None
+
+
class APIView(View):
- settings = api_settings
+ # The following policies may be set at either globally, or per-view.
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
@@ -43,6 +88,9 @@ class APIView(View):
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
+ # Allow dependancy injection of other settings to make testing easier.
+ settings = api_settings
+
@classmethod
def as_view(cls, **initkwargs):
"""
@@ -133,7 +181,7 @@ class APIView(View):
Return the view name, as used in OPTIONS responses and in the
browsable API.
"""
- func = api_settings.VIEW_NAME_FUNCTION
+ func = self.settings.VIEW_NAME_FUNCTION
return func(self.__class__, getattr(self, 'suffix', None))
def get_view_description(self, html=False):
@@ -141,7 +189,7 @@ class APIView(View):
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
"""
- func = api_settings.VIEW_DESCRIPTION_FUNCTION
+ func = self.settings.VIEW_DESCRIPTION_FUNCTION
return func(self.__class__, html)
# API policy instantiation methods
@@ -303,33 +351,23 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
- if isinstance(exc, exceptions.Throttled) and exc.wait is not None:
- # Throttle wait header
- self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
-
if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request)
if auth_header:
- self.headers['WWW-Authenticate'] = auth_header
+ exc.auth_header = auth_header
else:
exc.status_code = status.HTTP_403_FORBIDDEN
- if isinstance(exc, exceptions.APIException):
- return Response({'detail': exc.detail},
- status=exc.status_code,
- exception=True)
- elif isinstance(exc, Http404):
- return Response({'detail': 'Not found'},
- status=status.HTTP_404_NOT_FOUND,
- exception=True)
- elif isinstance(exc, PermissionDenied):
- return Response({'detail': 'Permission denied'},
- status=status.HTTP_403_FORBIDDEN,
- exception=True)
- raise
+ response = exception_handler(exc)
+
+ if response is None:
+ raise
+
+ response.exception = True
+ return response
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.