aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-08-30 09:28:33 +0100
committerTom Christie2013-08-30 09:28:33 +0100
commit9a5b2eefa92dede844ab94d049093e91ac98af5b (patch)
treefaf389e2f8c8296aeaa486ab97ed0be9113cc2ba /rest_framework
parentbf07b8e616bd92e4ae3c2c09b198181d7075e6bd (diff)
parentf3ab0b2b1d5734314dbe3cdd13cd7c4f0531bf7d (diff)
downloaddjango-rest-framework-9a5b2eefa92dede844ab94d049093e91ac98af5b.tar.bz2
Merge master
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/fields.py10
-rw-r--r--rest_framework/generics.py11
-rw-r--r--rest_framework/mixins.py15
-rw-r--r--rest_framework/parsers.py10
-rw-r--r--rest_framework/relations.py11
-rw-r--r--rest_framework/renderers.py321
-rw-r--r--rest_framework/request.py23
-rw-r--r--rest_framework/routers.py6
-rw-r--r--rest_framework/serializers.py81
-rw-r--r--rest_framework/settings.py10
-rw-r--r--rest_framework/static/rest_framework/js/default.js45
-rw-r--r--rest_framework/templates/rest_framework/base.html18
-rw-r--r--rest_framework/test.py2
-rw-r--r--rest_framework/tests/test_fields.py8
-rw-r--r--rest_framework/tests/test_files.py37
-rw-r--r--rest_framework/tests/test_generics.py11
-rw-r--r--rest_framework/tests/test_pagination.py47
-rw-r--r--rest_framework/tests/test_relations_nested.py351
-rw-r--r--rest_framework/tests/test_routers.py2
-rw-r--r--rest_framework/tests/test_testing.py30
-rw-r--r--rest_framework/throttling.py11
-rw-r--r--rest_framework/utils/breadcrumbs.py7
-rw-r--r--rest_framework/views.py88
23 files changed, 878 insertions, 277 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 518ba41a..b3a9b0df 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -301,7 +301,10 @@ class WritableField(Field):
try:
if self.use_files:
files = files or {}
- native = files[field_name]
+ try:
+ native = files[field_name]
+ except KeyError:
+ native = data[field_name]
else:
native = data[field_name]
except KeyError:
@@ -504,6 +507,11 @@ class ChoiceField(WritableField):
return True
return False
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+ return super(ChoiceField, self).from_native(value)
+
class EmailField(CharField):
type_name = 'EmailField'
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 8e6b8e26..851f8474 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -14,13 +14,15 @@ from rest_framework.settings import api_settings
import warnings
-def strict_positive_int(integer_string):
+def strict_positive_int(integer_string, cutoff=None):
"""
Cast a string to a strictly positive integer.
"""
ret = int(integer_string)
if ret <= 0:
raise ValueError()
+ if cutoff:
+ ret = min(ret, cutoff)
return ret
def get_object_or_404(queryset, **filter_kwargs):
@@ -56,6 +58,7 @@ class GenericAPIView(views.APIView):
# Pagination settings
paginate_by = api_settings.PAGINATE_BY
paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ max_paginate_by = api_settings.MAX_PAGINATE_BY
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
page_kwarg = 'page'
@@ -205,9 +208,11 @@ class GenericAPIView(views.APIView):
DeprecationWarning, stacklevel=2)
if self.paginate_by_param:
- query_params = self.request.QUERY_PARAMS
try:
- return strict_positive_int(query_params[self.paginate_by_param])
+ return strict_positive_int(
+ self.request.QUERY_PARAMS[self.paginate_by_param],
+ cutoff=self.max_paginate_by
+ )
except (KeyError, ValueError):
pass
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 679dfa6c..2c85d157 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -142,11 +142,16 @@ class UpdateModelMixin(object):
try:
return self.get_object()
except Http404:
- # If this is a PUT-as-create operation, we need to ensure that
- # we have relevant permissions, as if this was a POST request.
- # This will either raise a PermissionDenied exception,
- # or simply return None
- self.check_permissions(clone_request(self.request, 'POST'))
+ if self.request.method == 'PUT':
+ # For PUT-as-create operation, we need to ensure that we have
+ # relevant permissions, as if this was a POST request. This
+ # will either raise a PermissionDenied exception, or simply
+ # return None.
+ self.check_permissions(clone_request(self.request, 'POST'))
+ else:
+ # PATCH requests where the object does not exist should still
+ # return a 404 response.
+ raise
def pre_save(self, obj):
"""
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 96bfac84..98fc0341 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -10,9 +10,9 @@ from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter
-from rest_framework.compat import yaml, etree
+from rest_framework.compat import etree, six, yaml
from rest_framework.exceptions import ParseError
-from rest_framework.compat import six
+from rest_framework import renderers
import json
import datetime
import decimal
@@ -47,6 +47,7 @@ class JSONParser(BaseParser):
"""
media_type = 'application/json'
+ renderer_class = renderers.UnicodeJSONRenderer
def parse(self, stream, media_type=None, parser_context=None):
"""
@@ -121,7 +122,8 @@ class MultiPartParser(BaseParser):
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
- meta = request.META
+ meta = request.META.copy()
+ meta['CONTENT_TYPE'] = media_type
upload_handlers = request.upload_handlers
try:
@@ -129,7 +131,7 @@ class MultiPartParser(BaseParser):
data, files = parser.parse()
return DataAndFiles(data, files)
except MultiPartParserError as exc:
- raise ParseError('Multipart form parse error - %s' % six.u(exc))
+ raise ParseError('Multipart form parse error - %s' % str(exc))
class XMLParser(BaseParser):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index f1f7dea7..417925b5 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -126,9 +126,9 @@ class RelatedField(WritableField):
value = obj
for component in source.split('.'):
- value = get_component(value, component)
if value is None:
break
+ value = get_component(value, component)
except ObjectDoesNotExist:
return None
@@ -236,6 +236,8 @@ class PrimaryKeyRelatedField(RelatedField):
source = self.source or field_name
queryset = obj
for component in source.split('.'):
+ if queryset is None:
+ return []
queryset = get_component(queryset, component)
# Forward relationship
@@ -556,8 +558,13 @@ class HyperlinkedIdentityField(Field):
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf.
"""
- lookup_field = getattr(obj, self.lookup_field)
+ lookup_field = getattr(obj, self.lookup_field, None)
kwargs = {self.lookup_field: lookup_field}
+
+ # Handle unsaved object case
+ if lookup_field is None:
+ return None
+
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 1006e26c..fca67eee 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -21,10 +21,10 @@ from rest_framework.compat import six
from rest_framework.compat import smart_text
from rest_framework.compat import yaml
from rest_framework.settings import api_settings
-from rest_framework.request import clone_request
+from rest_framework.request import is_form_media_type, override_method
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework import exceptions, parsers, status, VERSION
+from rest_framework import exceptions, status, VERSION
class BaseRenderer(object):
@@ -36,6 +36,7 @@ class BaseRenderer(object):
media_type = None
format = None
charset = 'utf-8'
+ render_style = 'text'
def render(self, data, accepted_media_type=None, renderer_context=None):
raise NotImplemented('Renderer class requires .render() to be implemented')
@@ -51,16 +52,17 @@ class JSONRenderer(BaseRenderer):
format = 'json'
encoder_class = encoders.JSONEncoder
ensure_ascii = True
- charset = 'utf-8'
- # Note that JSON encodings must be utf-8, utf-16 or utf-32.
+ charset = None
+ # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32.
# See: http://www.ietf.org/rfc/rfc4627.txt
+ # Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
Render `data` into JSON.
"""
if data is None:
- return ''
+ return bytes()
# If 'indent' is provided in the context, then pretty print the result.
# E.g. If we're being called by the BrowsableAPIRenderer.
@@ -85,13 +87,12 @@ class JSONRenderer(BaseRenderer):
# and may (or may not) be unicode.
# On python 3.x json.dumps() returns unicode strings.
if isinstance(ret, six.text_type):
- return bytes(ret.encode(self.charset))
+ return bytes(ret.encode('utf-8'))
return ret
class UnicodeJSONRenderer(JSONRenderer):
ensure_ascii = False
- charset = 'utf-8'
"""
Renderer which serializes to JSON.
Does *not* apply JSON's character escaping for non-ascii characters.
@@ -108,6 +109,7 @@ class JSONPRenderer(JSONRenderer):
format = 'jsonp'
callback_parameter = 'callback'
default_callback = 'callback'
+ charset = 'utf-8'
def get_callback(self, renderer_context):
"""
@@ -316,6 +318,90 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):
return data
+class HTMLFormRenderer(BaseRenderer):
+ """
+ Renderers serializer data into an HTML form.
+
+ If the serializer was instantiated without an object then this will
+ return an HTML form not bound to any object,
+ otherwise it will return an HTML form with the appropriate initial data
+ populated from the object.
+
+ Note that rendering of field and form errors is not currently supported.
+ """
+ media_type = 'text/html'
+ format = 'form'
+ template = 'rest_framework/form.html'
+ charset = 'utf-8'
+
+ def data_to_form_fields(self, data):
+ fields = {}
+ for key, val in data.fields.items():
+ if getattr(val, 'read_only', True):
+ # Don't include read-only fields.
+ continue
+
+ if getattr(val, 'fields', None):
+ # Nested data not supported by HTML forms.
+ continue
+
+ kwargs = {}
+ kwargs['required'] = val.required
+
+ #if getattr(v, 'queryset', None):
+ # kwargs['queryset'] = v.queryset
+
+ if getattr(val, 'choices', None) is not None:
+ kwargs['choices'] = val.choices
+
+ if getattr(val, 'regex', None) is not None:
+ kwargs['regex'] = val.regex
+
+ if getattr(val, 'widget', None):
+ widget = copy.deepcopy(val.widget)
+ kwargs['widget'] = widget
+
+ if getattr(val, 'default', None) is not None:
+ kwargs['initial'] = val.default
+
+ if getattr(val, 'label', None) is not None:
+ kwargs['label'] = val.label
+
+ if getattr(val, 'help_text', None) is not None:
+ kwargs['help_text'] = val.help_text
+
+ fields[key] = val.form_field_class(**kwargs)
+
+ return fields
+
+ def render(self, data, accepted_media_type=None, renderer_context=None):
+ """
+ Render serializer data and return an HTML form, as a string.
+ """
+ # The HTMLFormRenderer currently uses something of a hack to render
+ # the content, by translating each of the serializer fields into
+ # an html form field, creating a dynamic form using those fields,
+ # and then rendering that form.
+
+ # This isn't strictly neccessary, as we could render the serilizer
+ # fields to HTML directly. The implementation is historical and will
+ # likely change at some point.
+
+ self.renderer_context = renderer_context or {}
+ request = renderer_context['request']
+
+ # Creating an on the fly form see:
+ # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
+ fields = self.data_to_form_fields(data)
+ DynamicForm = type(str('DynamicForm'), (forms.Form,), fields)
+ data = None if data.empty else data
+
+ template = loader.get_template(self.template)
+ context = RequestContext(request, {'form': DynamicForm(data)})
+
+ return template.render(context)
+
+
class BrowsableAPIRenderer(BaseRenderer):
"""
HTML renderer used to self-document the API.
@@ -324,6 +410,7 @@ class BrowsableAPIRenderer(BaseRenderer):
format = 'api'
template = 'rest_framework/api.html'
charset = 'utf-8'
+ form_renderer_class = HTMLFormRenderer
def get_default_renderer(self, view):
"""
@@ -348,7 +435,10 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer_context['indent'] = 4
content = renderer.render(data, accepted_media_type, renderer_context)
- if renderer.charset is None:
+ render_style = getattr(renderer, 'render_style', 'text')
+ assert render_style in ['text', 'binary'], 'Expected .render_style ' \
+ '"text" or "binary", but got "%s"' % render_style
+ if render_style == 'binary':
return '[%d bytes of binary content]' % len(content)
return content
@@ -371,130 +461,99 @@ class BrowsableAPIRenderer(BaseRenderer):
return False # Doesn't have permissions
return True
- def serializer_to_form_fields(self, serializer):
- fields = {}
- for k, v in serializer.get_fields().items():
- if getattr(v, 'read_only', True):
- continue
-
- kwargs = {}
- kwargs['required'] = v.required
-
- #if getattr(v, 'queryset', None):
- # kwargs['queryset'] = v.queryset
-
- if getattr(v, 'choices', None) is not None:
- kwargs['choices'] = v.choices
-
- if getattr(v, 'regex', None) is not None:
- kwargs['regex'] = v.regex
-
- if getattr(v, 'widget', None):
- widget = copy.deepcopy(v.widget)
- kwargs['widget'] = widget
-
- if getattr(v, 'default', None) is not None:
- kwargs['initial'] = v.default
-
- if getattr(v, 'label', None) is not None:
- kwargs['label'] = v.label
-
- if getattr(v, 'help_text', None) is not None:
- kwargs['help_text'] = v.help_text
-
- fields[k] = v.form_field_class(**kwargs)
-
- return fields
-
- def _get_form(self, view, method, request):
- # We need to impersonate a request with the correct method,
- # so that eg. any dynamic get_serializer_class methods return the
- # correct form for each method.
- restore = view.request
- request = clone_request(request, method)
- view.request = request
- try:
- return self.get_form(view, method, request)
- finally:
- view.request = restore
-
- def _get_raw_data_form(self, view, method, request, media_types):
- # We need to impersonate a request with the correct method,
- # so that eg. any dynamic get_serializer_class methods return the
- # correct form for each method.
- restore = view.request
- request = clone_request(request, method)
- view.request = request
- try:
- return self.get_raw_data_form(view, method, request, media_types)
- finally:
- view.request = restore
-
- def get_form(self, view, method, request):
+ def get_rendered_html_form(self, view, method, request):
"""
- Get a form, possibly bound to either the input or output data.
- In the absence on of the Resource having an associated form then
- provide a form that can be used to submit arbitrary content.
+ Return a string representing a rendered HTML form, possibly bound to
+ either the input or output data.
+
+ In the absence of the View having an associated form then return None.
"""
- obj = getattr(view, 'object', None)
- if not self.show_form_for_method(view, method, request, obj):
- return
+ with override_method(view, request, method) as request:
+ obj = getattr(view, 'object', None)
+ if not self.show_form_for_method(view, method, request, obj):
+ return
- if method in ('DELETE', 'OPTIONS'):
- return True # Don't actually need to return a form
+ if method in ('DELETE', 'OPTIONS'):
+ return True # Don't actually need to return a form
- if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes:
- return
+ if (not getattr(view, 'get_serializer', None)
+ or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)):
+ return
- serializer = view.get_serializer(instance=obj)
- fields = self.serializer_to_form_fields(serializer)
+ serializer = view.get_serializer(instance=obj)
- # Creating an on the fly form see:
- # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
- OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields)
- data = (obj is not None) and serializer.data or None
- form_instance = OnTheFlyForm(data)
- return form_instance
+ data = serializer.data
+ form_renderer = self.form_renderer_class()
+ return form_renderer.render(data, self.accepted_media_type, self.renderer_context)
- def get_raw_data_form(self, view, method, request, media_types):
+ def get_raw_data_form(self, view, method, request):
"""
Returns a form that allows for arbitrary content types to be tunneled
via standard HTML forms.
(Which are typically application/x-www-form-urlencoded)
"""
-
- # If we're not using content overloading there's no point in supplying a generic form,
- # as the view won't treat the form's value as the content of the request.
- if not (api_settings.FORM_CONTENT_OVERRIDE
- and api_settings.FORM_CONTENTTYPE_OVERRIDE):
- return None
-
- # Check permissions
- obj = getattr(view, 'object', None)
- if not self.show_form_for_method(view, method, request, obj):
- return
-
- content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE
- content_field = api_settings.FORM_CONTENT_OVERRIDE
- choices = [(media_type, media_type) for media_type in media_types]
- initial = media_types[0]
-
- # NB. http://jacobian.org/writing/dynamic-form-generation/
- class GenericContentForm(forms.Form):
- def __init__(self):
- super(GenericContentForm, self).__init__()
-
- self.fields[content_type_field] = forms.ChoiceField(
- label='Media type',
- choices=choices,
- initial=initial
- )
- self.fields[content_field] = forms.CharField(
- label='Content',
- widget=forms.Textarea
- )
-
- return GenericContentForm()
+ with override_method(view, request, method) as request:
+ # If we're not using content overloading there's no point in
+ # supplying a generic form, as the view won't treat the form's
+ # value as the content of the request.
+ if not (api_settings.FORM_CONTENT_OVERRIDE
+ and api_settings.FORM_CONTENTTYPE_OVERRIDE):
+ return None
+
+ # Check permissions
+ obj = getattr(view, 'object', None)
+ if not self.show_form_for_method(view, method, request, obj):
+ return
+
+ # If possible, serialize the initial content for the generic form
+ default_parser = view.parser_classes[0]
+ renderer_class = getattr(default_parser, 'renderer_class', None)
+ if (hasattr(view, 'get_serializer') and renderer_class):
+ # View has a serializer defined and parser class has a
+ # corresponding renderer that can be used to render the data.
+
+ # Get a read-only version of the serializer
+ serializer = view.get_serializer(instance=obj)
+ if obj is None:
+ for name, field in serializer.fields.items():
+ if getattr(field, 'read_only', None):
+ del serializer.fields[name]
+
+ # Render the raw data content
+ renderer = renderer_class()
+ accepted = self.accepted_media_type
+ context = self.renderer_context.copy()
+ context['indent'] = 4
+ content = renderer.render(serializer.data, accepted, context)
+ else:
+ content = None
+
+ # Generate a generic form that includes a content type field,
+ # and a content field.
+ content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE
+ content_field = api_settings.FORM_CONTENT_OVERRIDE
+
+ media_types = [parser.media_type for parser in view.parser_classes]
+ choices = [(media_type, media_type) for media_type in media_types]
+ initial = media_types[0]
+
+ # NB. http://jacobian.org/writing/dynamic-form-generation/
+ class GenericContentForm(forms.Form):
+ def __init__(self):
+ super(GenericContentForm, self).__init__()
+
+ self.fields[content_type_field] = forms.ChoiceField(
+ label='Media type',
+ choices=choices,
+ initial=initial
+ )
+ self.fields[content_field] = forms.CharField(
+ label='Content',
+ widget=forms.Textarea,
+ initial=content
+ )
+
+ return GenericContentForm()
def get_name(self, view):
return view.get_view_name()
@@ -509,26 +568,25 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
Render the HTML for the browsable API representation.
"""
- accepted_media_type = accepted_media_type or ''
- renderer_context = renderer_context or {}
+ self.accepted_media_type = accepted_media_type or ''
+ self.renderer_context = renderer_context or {}
view = renderer_context['view']
request = renderer_context['request']
response = renderer_context['response']
- media_types = [parser.media_type for parser in view.parser_classes]
renderer = self.get_default_renderer(view)
content = self.get_content(renderer, data, accepted_media_type, renderer_context)
- put_form = self._get_form(view, 'PUT', request)
- post_form = self._get_form(view, 'POST', request)
- patch_form = self._get_form(view, 'PATCH', request)
- delete_form = self._get_form(view, 'DELETE', request)
- options_form = self._get_form(view, 'OPTIONS', request)
+ put_form = self.get_rendered_html_form(view, 'PUT', request)
+ post_form = self.get_rendered_html_form(view, 'POST', request)
+ patch_form = self.get_rendered_html_form(view, 'PATCH', request)
+ delete_form = self.get_rendered_html_form(view, 'DELETE', request)
+ options_form = self.get_rendered_html_form(view, 'OPTIONS', request)
- raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types)
- raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types)
- raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types)
+ raw_data_put_form = self.get_raw_data_form(view, 'PUT', request)
+ raw_data_post_form = self.get_raw_data_form(view, 'POST', request)
+ raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
name = self.get_name(view)
@@ -581,3 +639,4 @@ class MultiPartRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None):
return encode_multipart(self.BOUNDARY, data)
+
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 919716f4..977d4d96 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -28,6 +28,29 @@ def is_form_media_type(media_type):
base_media_type == 'multipart/form-data')
+class override_method(object):
+ """
+ A context manager that temporarily overrides the method on a request,
+ additionally setting the `view.request` attribute.
+
+ Usage:
+
+ with override_method(view, request, 'POST') as request:
+ ... # Do stuff with `view` and `request`
+ """
+ def __init__(self, view, request, method):
+ self.view = view
+ self.request = request
+ self.method = method
+
+ def __enter__(self):
+ self.view.request = clone_request(self.request, self.method)
+ return self.view.request
+
+ def __exit__(self, *args, **kwarg):
+ self.view.request = self.request
+
+
class Empty(object):
"""
Placeholder for unset attributes.
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index b761ba9a..1c7a8158 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -213,7 +213,11 @@ class SimpleRouter(BaseRouter):
Given a viewset, return the portion of URL regex that is used
to match against a single instance.
"""
- base_regex = '(?P<{lookup_field}>[^/]+)'
+ if self.trailing_slash:
+ base_regex = '(?P<{lookup_field}>[^/]+)'
+ else:
+ # Don't consume `.json` style suffixes
+ base_regex = '(?P<{lookup_field}>[^/.]+)'
lookup_field = getattr(viewset, 'lookup_field', 'pk')
return base_regex.format(lookup_field=lookup_field)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index b3850157..f1775762 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -31,6 +31,9 @@ from rest_framework.relations import *
from rest_framework.fields import *
+class RelationsList(list):
+ _deleted = []
+
class NestedValidationError(ValidationError):
"""
The default ValidationError behavior is to stringify each item in the list
@@ -160,7 +163,6 @@ class BaseSerializer(WritableField):
self._data = None
self._files = None
self._errors = None
- self._deleted = None
if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True')
@@ -297,7 +299,8 @@ class BaseSerializer(WritableField):
Serialize objects -> primitives.
"""
ret = self._dict_class()
- ret.fields = {}
+ ret.fields = self._dict_class()
+ ret.empty = obj is None
for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name)
@@ -330,14 +333,15 @@ class BaseSerializer(WritableField):
if self.source == '*':
return self.to_native(obj)
+ # Get the raw field value
try:
source = self.source or field_name
value = obj
for component in source.split('.'):
- value = get_component(value, component)
if value is None:
break
+ value = get_component(value, component)
except ObjectDoesNotExist:
return None
@@ -372,6 +376,7 @@ class BaseSerializer(WritableField):
# Set the serializer object if it exists
obj = getattr(self.parent.object, field_name) if self.parent.object else None
+ obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
if self.source == '*':
if value:
@@ -385,7 +390,8 @@ class BaseSerializer(WritableField):
'data': value,
'context': self.context,
'partial': self.partial,
- 'many': self.many
+ 'many': self.many,
+ 'allow_add_remove': self.allow_add_remove
}
serializer = self.__class__(**kwargs)
@@ -418,8 +424,17 @@ class BaseSerializer(WritableField):
if self._errors is None:
data, files = self.init_data, self.init_files
- if self.many:
- ret = []
+ if self.many is not None:
+ many = self.many
+ else:
+ many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
+ if many:
+ warnings.warn('Implict list/queryset serialization is deprecated. '
+ 'Use the `many=True` flag when instantiating the serializer.',
+ DeprecationWarning, stacklevel=3)
+
+ if many:
+ ret = RelationsList()
errors = []
update = self.object is not None
@@ -446,8 +461,8 @@ class BaseSerializer(WritableField):
ret.append(self.from_native(item, None))
errors.append(self._errors)
- if update:
- self._deleted = identity_to_objects.values()
+ if update and self.allow_add_remove:
+ ret._deleted = identity_to_objects.values()
self._errors = any(errors) and errors or []
else:
@@ -490,12 +505,12 @@ class BaseSerializer(WritableField):
"""
if isinstance(self.object, list):
[self.save_object(item, **kwargs) for item in self.object]
+
+ if self.object._deleted:
+ [self.delete_object(item) for item in self.object._deleted]
else:
self.save_object(self.object, **kwargs)
- if self.allow_add_remove and self._deleted:
- [self.delete_object(item) for item in self._deleted]
-
return self.object
def metadata(self):
@@ -771,9 +786,12 @@ class ModelSerializer(Serializer):
cls = self.opts.model
opts = get_concrete_model(cls)._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
+
for field_name, field in self.fields.items():
field_name = field.source or field_name
- if field_name in exclusions and not field.read_only:
+ if field_name in exclusions \
+ and not field.read_only \
+ and not isinstance(field, Serializer):
exclusions.remove(field_name)
return exclusions
@@ -799,6 +817,7 @@ class ModelSerializer(Serializer):
"""
m2m_data = {}
related_data = {}
+ nested_forward_relations = {}
meta = self.opts.model._meta
# Reverse fk or one-to-one relations
@@ -818,6 +837,12 @@ class ModelSerializer(Serializer):
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
+ # Nested forward relations - These need to be marked so we can save
+ # them before saving the parent model instance.
+ for field_name in attrs.keys():
+ if isinstance(self.fields.get(field_name, None), Serializer):
+ nested_forward_relations[field_name] = attrs[field_name]
+
# Update an existing instance...
if instance is not None:
for key, val in attrs.items():
@@ -833,6 +858,7 @@ class ModelSerializer(Serializer):
# at the point of save.
instance._related_data = related_data
instance._m2m_data = m2m_data
+ instance._nested_forward_relations = nested_forward_relations
return instance
@@ -848,6 +874,14 @@ class ModelSerializer(Serializer):
"""
Save the deserialized object and return it.
"""
+ if getattr(obj, '_nested_forward_relations', None):
+ # Nested relationships need to be saved before we can save the
+ # parent instance.
+ for field_name, sub_object in obj._nested_forward_relations.items():
+ if sub_object:
+ self.save_object(sub_object)
+ setattr(obj, field_name, sub_object)
+
obj.save(**kwargs)
if getattr(obj, '_m2m_data', None):
@@ -857,7 +891,25 @@ class ModelSerializer(Serializer):
if getattr(obj, '_related_data', None):
for accessor_name, related in obj._related_data.items():
- setattr(obj, accessor_name, related)
+ if isinstance(related, RelationsList):
+ # Nested reverse fk relationship
+ for related_item in related:
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related_item, fk_field, obj)
+ self.save_object(related_item)
+
+ # Delete any removed objects
+ if related._deleted:
+ [self.delete_object(item) for item in related._deleted]
+
+ elif isinstance(related, models.Model):
+ # Nested reverse one-one relationship
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related, fk_field, obj)
+ self.save_object(related)
+ else:
+ # Reverse FK or reverse one-one
+ setattr(obj, accessor_name, related)
del(obj._related_data)
@@ -879,6 +931,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
+ _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
@@ -887,7 +940,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
self.opts.view_name = self._get_default_view_name(self.opts.model)
if 'url' not in fields:
- url_field = HyperlinkedIdentityField(
+ url_field = self._hyperlink_identify_field_class(
view_name=self.opts.view_name,
lookup_field=self.opts.lookup_field
)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 7d25e513..8c084751 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -48,7 +48,6 @@ DEFAULTS = {
),
'DEFAULT_THROTTLE_CLASSES': (
),
-
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
@@ -68,15 +67,16 @@ DEFAULTS = {
# Pagination
'PAGINATE_BY': None,
'PAGINATE_BY_PARAM': None,
-
- # View configuration
- 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
- 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
+ 'MAX_PAGINATE_BY': None,
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # View configuration
+ 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
+ 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
+
# Testing
'TEST_REQUEST_RENDERER_CLASSES': (
'rest_framework.renderers.MultiPartRenderer',
diff --git a/rest_framework/static/rest_framework/js/default.js b/rest_framework/static/rest_framework/js/default.js
index c74829d7..bcb1964d 100644
--- a/rest_framework/static/rest_framework/js/default.js
+++ b/rest_framework/static/rest_framework/js/default.js
@@ -1,13 +1,56 @@
+function getCookie(c_name)
+{
+ // From http://www.w3schools.com/js/js_cookies.asp
+ var c_value = document.cookie;
+ var c_start = c_value.indexOf(" " + c_name + "=");
+ if (c_start == -1) {
+ c_start = c_value.indexOf(c_name + "=");
+ }
+ if (c_start == -1) {
+ c_value = null;
+ } else {
+ c_start = c_value.indexOf("=", c_start) + 1;
+ var c_end = c_value.indexOf(";", c_start);
+ if (c_end == -1) {
+ c_end = c_value.length;
+ }
+ c_value = unescape(c_value.substring(c_start,c_end));
+ }
+ return c_value;
+}
+
+// JSON highlighting.
prettyPrint();
+// Bootstrap tooltips.
$('.js-tooltip').tooltip({
delay: 1000
});
+// Deal with rounded tab styling after tab clicks.
$('a[data-toggle="tab"]:first').on('shown', function (e) {
$(e.target).parents('.tabbable').addClass('first-tab-active');
});
$('a[data-toggle="tab"]:not(:first)').on('shown', function (e) {
$(e.target).parents('.tabbable').removeClass('first-tab-active');
});
-$('.form-switcher a:first').tab('show');
+
+$('a[data-toggle="tab"]').click(function(){
+ document.cookie="tabstyle=" + this.name + "; path=/";
+});
+
+// Store tab preference in cookies & display appropriate tab on load.
+var selectedTab = null;
+var selectedTabName = getCookie('tabstyle');
+
+if (selectedTabName) {
+ selectedTab = $('.form-switcher a[name=' + selectedTabName + ']');
+}
+
+if (selectedTab && selectedTab.length > 0) {
+ // Display whichever tab is selected.
+ selectedTab.tab('show');
+} else {
+ // If no tab selected, display rightmost tab.
+ $('.form-switcher a:first').tab('show');
+}
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 51f9c291..aa90e90c 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -128,17 +128,17 @@
<div {% if post_form %}class="tabbable"{% endif %}>
{% if post_form %}
<ul class="nav nav-tabs form-switcher">
- <li><a href="#object-form" data-toggle="tab">HTML form</a></li>
- <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li>
+ <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li>
+ <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li>
</ul>
{% endif %}
<div class="well tab-content">
{% if post_form %}
<div class="tab-pane" id="object-form">
{% with form=post_form %}
- <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
+ <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">
<fieldset>
- {% include "rest_framework/form.html" %}
+ {{ post_form }}
<div class="form-actions">
<button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>
</div>
@@ -167,23 +167,21 @@
<div {% if put_form %}class="tabbable"{% endif %}>
{% if put_form %}
<ul class="nav nav-tabs form-switcher">
- <li><a href="#object-form" data-toggle="tab">HTML form</a></li>
- <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li>
+ <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li>
+ <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li>
</ul>
{% endif %}
<div class="well tab-content">
{% if put_form %}
<div class="tab-pane" id="object-form">
- {% with form=put_form %}
- <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal">
+ <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">
<fieldset>
- {% include "rest_framework/form.html" %}
+ {{ put_form }}
<div class="form-actions">
<button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
</div>
</fieldset>
</form>
- {% endwith %}
</div>
{% endif %}
<div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form">
diff --git a/rest_framework/test.py b/rest_framework/test.py
index a18f5a29..234d10a4 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient):
"""
self.handler._force_user = user
self.handler._force_token = token
+ if user is None:
+ self.logout() # Also clear any possible session info if required
def request(self, **kwargs):
# Ensure that any credentials set get added to every request.
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index ebccba7d..34fbab9c 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase):
f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES)
+ result = f.from_native('')
+ self.assertEqual(result, None)
+
class EmailFieldTests(TestCase):
"""
diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py
index 487046ac..c13c38b8 100644
--- a/rest_framework/tests/test_files.py
+++ b/rest_framework/tests/test_files.py
@@ -7,13 +7,13 @@ import datetime
class UploadedFile(object):
- def __init__(self, file, created=None):
+ def __init__(self, file=None, created=None):
self.file = file
self.created = created or datetime.datetime.now()
class UploadedFileSerializer(serializers.Serializer):
- file = serializers.FileField()
+ file = serializers.FileField(required=False)
created = serializers.DateTimeField()
def restore_object(self, attrs, instance=None):
@@ -47,5 +47,36 @@ class FileSerializerTests(TestCase):
now = datetime.datetime.now()
serializer = UploadedFileSerializer(data={'created': now})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, now)
+ self.assertIsNone(serializer.object.file)
+
+ def test_remove_with_empty_string(self):
+ """
+ Passing empty string as data should cause file to be removed
+
+ Test for:
+ https://github.com/tomchristie/django-rest-framework/issues/937
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertIsNone(serializer.object.file)
+
+ def test_validation_error_with_non_file(self):
+ """
+ Passing non-files should raise a validation error.
+ """
+ now = datetime.datetime.now()
+ errmsg = 'No file was submitted. Check the encoding type on the form.'
+
+ serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid())
- self.assertIn('file', serializer.errors)
+ self.assertEqual(serializer.errors, {'file': [errmsg]})
diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py
index 1550880b..7a87d389 100644
--- a/rest_framework/tests/test_generics.py
+++ b/rest_framework/tests/test_generics.py
@@ -338,6 +338,17 @@ class TestInstanceView(TestCase):
new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar')
+ def test_patch_cannot_create_an_object(self):
+ """
+ PATCH requests should not be able to create objects.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/999', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertFalse(self.objects.filter(id=999).exists())
+
class TestOverriddenGetObject(TestCase):
"""
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
index 85d4640e..4170d4b6 100644
--- a/rest_framework/tests/test_pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView):
paginate_by_param = 'page_size'
+class MaxPaginateByView(generics.ListAPIView):
+ """
+ View for testing custom max_paginate_by usage
+ """
+ model = BasicModel
+ paginate_by = 3
+ max_paginate_by = 5
+ paginate_by_param = 'page_size'
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase):
self.assertEqual(response.data['results'], self.data[:5])
+class TestMaxPaginateByParam(TestCase):
+ """
+ Tests for list views with max_paginate_by kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = MaxPaginateByView.as_view()
+
+ def test_max_paginate_by(self):
+ """
+ If max_paginate_by is set, it should limit page size for the view.
+ """
+ request = factory.get('/?page_size=10')
+ response = self.view(request).render()
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+ def test_max_paginate_by_without_page_size_param(self):
+ """
+ If max_paginate_by is set, but client does not specifiy page_size,
+ standard `paginate_by` behavior should be used.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEqual(response.data['results'], self.data[:3])
+
+
### Tests for context in pagination serializers
class CustomField(serializers.Field):
diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py
index f6d006b3..d393b0c3 100644
--- a/rest_framework/tests/test_relations_nested.py
+++ b/rest_framework/tests/test_relations_nested.py
@@ -1,107 +1,328 @@
from __future__ import unicode_literals
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeySource
- fields = ('id', 'name', 'target')
- depth = 1
+class OneToOneTarget(models.Model):
+ name = models.CharField(max_length=100)
-class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeyTarget
- fields = ('id', 'name', 'sources')
- depth = 1
+class OneToOneSource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, related_name='source',
+ null=True, blank=True)
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
- fields = ('id', 'name', 'target')
- depth = 1
+class OneToManyTarget(models.Model):
+ name = models.CharField(max_length=100)
-class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name', 'nullable_source')
- depth = 1
+class OneToManySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(OneToManyTarget, related_name='sources')
-class ReverseForeignKeyTests(TestCase):
+class ReverseNestedOneToOneTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name')
+
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ source = OneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'source')
+
+ self.Serializer = OneToOneTargetSerializer
+
for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
source.save()
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
]
self.assertEqual(serializer.data, expected)
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1},
- ]},
- {'id': 2, 'name': 'target-2', 'sources': [
- ]}
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
+ {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
]
self.assertEqual(serializer.data, expected)
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
-class NestedNullableForeignKeyTests(TestCase):
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ instance = OneToOneTarget.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class ForwardNestedOneToOneTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name')
+
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ target = OneToOneTargetSerializer()
+
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name', 'target')
+
+ self.Serializer = OneToOneSourceSerializer
+
for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
source.save()
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
+ {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
+
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
]
self.assertEqual(serializer.data, expected)
+ def test_one_to_one_update_to_null(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': None}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
-class NestedNullableOneToOneTests(TestCase):
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+ self.assertEqual(obj.target, None)
+
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # TODO: Nullable 1-1 tests
+ # def test_one_to_one_delete(self):
+ # data = {'id': 3, 'name': 'target-3', 'target_source': None}
+ # instance = OneToOneTarget.objects.get(pk=3)
+ # serializer = self.Serializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # serializer.save()
+
+ # # Ensure (target_source 3, source 3) are deleted,
+ # # and everything else is as expected.
+ # queryset = OneToOneTarget.objects.all()
+ # serializer = self.Serializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ # {'id': 3, 'name': 'target-3', 'source': None}
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class ReverseNestedOneToManyTests(TestCase):
def setUp(self):
- target = OneToOneTarget(name='target-1')
+ class OneToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToManySource
+ fields = ('id', 'name')
+
+ class OneToManyTargetSerializer(serializers.ModelSerializer):
+ sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
+
+ class Meta:
+ model = OneToManyTarget
+ fields = ('id', 'name', 'sources')
+
+ self.Serializer = OneToManyTargetSerializer
+
+ target = OneToManyTarget(name='target-1')
target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
- source.save()
+ for idx in range(1, 4):
+ source = OneToManySource(name='source-%d' % idx, target=target)
+ source.save()
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True)
+ def test_one_to_many_retrieve(self):
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1')
+
+ # Ensure source 4 is added, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}},
- {'id': 2, 'name': 'target-2', 'nullable_source': None},
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create_with_invalid_data(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4}]}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
+
+ def test_one_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1-updated')
+
+ # Ensure (target 1, source 1) are updated,
+ # and everything else is as expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_delete(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+
+ # Ensure source 2 is deleted, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+
]
self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index c3597e38..3f456fef 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase):
self.urls = self.router.urls
def test_urls_can_have_trailing_slash_removed(self):
- expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
+ expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 49d45fc2..48b8956b 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -17,8 +17,18 @@ def view(request):
})
+@api_view(['GET', 'POST'])
+def session_view(request):
+ active_session = request.session.get('active_session', False)
+ request.session['active_session'] = True
+ return Response({
+ 'active_session': active_session
+ })
+
+
urlpatterns = patterns('',
url(r'^view/$', view),
+ url(r'^session-view/$', session_view),
)
@@ -46,6 +56,26 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/')
self.assertEqual(response.data['user'], 'example')
+ def test_force_authenticate_with_sessions(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+
+ # First request does not yet have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ # Subsequant requests have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], True)
+
+ # Force authenticating as `None` should also logout the user session.
+ self.client.force_authenticate(None)
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 65b45593..a946d837 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -2,7 +2,7 @@
Provides various throttling policies.
"""
from __future__ import unicode_literals
-from django.core.cache import cache
+from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings
import time
@@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache.
"""
+ cache = default_cache
timer = time.time
cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None
@@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle):
if self.key is None:
return True
- self.history = cache.get(self.key, [])
+ self.history = self.cache.get(self.key, [])
self.now = self.timer()
# Drop any requests from the history which have now passed the
@@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle):
into the cache.
"""
self.history.insert(0, self.now)
- cache.set(self.key, self.history, self.duration)
+ self.cache.set(self.key, self.history, self.duration)
return True
def throttle_failure(self):
@@ -151,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
return None # Only throttle unauthenticated requests.
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = request.META.get('HTTP_X_FORWARDED_FOR')
+ if ident is None:
+ ident = request.META.get('REMOTE_ADDR')
return self.cache_format % {
'scope': self.scope,
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index 0384faba..e6690d17 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -8,8 +8,11 @@ def get_breadcrumbs(url):
tuple of (name, url).
"""
+ from rest_framework.settings import api_settings
from rest_framework.views import APIView
+ view_name_func = api_settings.VIEW_NAME_FUNCTION
+
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
"""
Add tuples of (name, url) to the breadcrumbs list,
@@ -28,8 +31,8 @@ def get_breadcrumbs(url):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
- instance = view.cls()
- name = instance.get_view_name()
+ suffix = getattr(view, 'suffix', None)
+ name = view_name_func(cls, suffix)
breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 727a9f95..4cff0422 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -15,8 +15,14 @@ from rest_framework.settings import api_settings
from rest_framework.utils import formatting
-def get_view_name(cls, suffix=None):
- name = cls.__name__
+def get_view_name(view_cls, suffix=None):
+ """
+ Given a view class, return a textual name to represent the view.
+ This name is used in the browsable API, and in OPTIONS responses.
+
+ This function is the default for the `VIEW_NAME_FUNCTION` setting.
+ """
+ name = view_cls.__name__
name = formatting.remove_trailing_string(name, 'View')
name = formatting.remove_trailing_string(name, 'ViewSet')
name = formatting.camelcase_to_spaces(name)
@@ -25,17 +31,56 @@ def get_view_name(cls, suffix=None):
return name
-def get_view_description(cls, html=False):
- description = cls.__doc__ or ''
+def get_view_description(view_cls, html=False):
+ """
+ Given a view class, return a textual description to represent the view.
+ This name is used in the browsable API, and in OPTIONS responses.
+
+ This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting.
+ """
+ description = view_cls.__doc__ or ''
description = formatting.dedent(smart_text(description))
if html:
return formatting.markup_description(description)
return description
+def exception_handler(exc):
+ """
+ Returns the response that should be used for any given exception.
+
+ By default we handle the REST framework `APIException`, and also
+ Django's builtin `Http404` and `PermissionDenied` exceptions.
+
+ Any unhandled exceptions may return `None`, which will cause a 500 error
+ to be raised.
+ """
+ if isinstance(exc, exceptions.APIException):
+ headers = {}
+ if getattr(exc, 'auth_header', None):
+ headers['WWW-Authenticate'] = exc.auth_header
+ if getattr(exc, 'wait', None):
+ headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+
+ return Response({'detail': exc.detail},
+ status=exc.status_code,
+ headers=headers)
+
+ elif isinstance(exc, Http404):
+ return Response({'detail': 'Not found'},
+ status=status.HTTP_404_NOT_FOUND)
+
+ elif isinstance(exc, PermissionDenied):
+ return Response({'detail': 'Permission denied'},
+ status=status.HTTP_403_FORBIDDEN)
+
+ # Note: Unhandled exceptions will raise a 500 error.
+ return None
+
+
class APIView(View):
- settings = api_settings
+ # The following policies may be set at either globally, or per-view.
renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
parser_classes = api_settings.DEFAULT_PARSER_CLASSES
authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
@@ -43,6 +88,9 @@ class APIView(View):
permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
+ # Allow dependancy injection of other settings to make testing easier.
+ settings = api_settings
+
@classmethod
def as_view(cls, **initkwargs):
"""
@@ -133,7 +181,7 @@ class APIView(View):
Return the view name, as used in OPTIONS responses and in the
browsable API.
"""
- func = api_settings.VIEW_NAME_FUNCTION
+ func = self.settings.VIEW_NAME_FUNCTION
return func(self.__class__, getattr(self, 'suffix', None))
def get_view_description(self, html=False):
@@ -141,7 +189,7 @@ class APIView(View):
Return some descriptive text for the view, as used in OPTIONS responses
and in the browsable API.
"""
- func = api_settings.VIEW_DESCRIPTION_FUNCTION
+ func = self.settings.VIEW_DESCRIPTION_FUNCTION
return func(self.__class__, html)
# API policy instantiation methods
@@ -303,33 +351,23 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
- if isinstance(exc, exceptions.Throttled) and exc.wait is not None:
- # Throttle wait header
- self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
-
if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request)
if auth_header:
- self.headers['WWW-Authenticate'] = auth_header
+ exc.auth_header = auth_header
else:
exc.status_code = status.HTTP_403_FORBIDDEN
- if isinstance(exc, exceptions.APIException):
- return Response({'detail': exc.detail},
- status=exc.status_code,
- exception=True)
- elif isinstance(exc, Http404):
- return Response({'detail': 'Not found'},
- status=status.HTTP_404_NOT_FOUND,
- exception=True)
- elif isinstance(exc, PermissionDenied):
- return Response({'detail': 'Permission denied'},
- status=status.HTTP_403_FORBIDDEN,
- exception=True)
- raise
+ response = exception_handler(exc)
+
+ if response is None:
+ raise
+
+ response.exception = True
+ return response
# Note: session based authentication is explicitly CSRF validated,
# all other authentication is CSRF exempt.