aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py18
-rw-r--r--rest_framework/compat.py23
-rw-r--r--rest_framework/fields.py40
-rw-r--r--rest_framework/filters.py5
-rw-r--r--rest_framework/generics.py37
-rw-r--r--rest_framework/mixins.py13
-rw-r--r--rest_framework/parsers.py4
-rw-r--r--rest_framework/permissions.py7
-rw-r--r--rest_framework/renderers.py90
-rw-r--r--rest_framework/request.py13
-rw-r--r--rest_framework/response.py4
-rw-r--r--rest_framework/serializers.py77
-rw-r--r--rest_framework/status.py17
-rw-r--r--rest_framework/templates/rest_framework/base.html9
-rw-r--r--rest_framework/templates/rest_framework/form.html12
-rw-r--r--rest_framework/templates/rest_framework/raw_data_form.html12
-rw-r--r--rest_framework/tests/test_fields.py92
-rw-r--r--rest_framework/tests/test_files.py13
-rw-r--r--rest_framework/tests/test_filters.py39
-rw-r--r--rest_framework/tests/test_generics.py60
-rw-r--r--rest_framework/tests/test_pagination.py85
-rw-r--r--rest_framework/tests/test_permissions.py4
-rw-r--r--rest_framework/tests/test_renderers.py85
-rw-r--r--rest_framework/tests/test_request.py32
-rw-r--r--rest_framework/tests/test_serializer.py104
-rw-r--r--rest_framework/tests/test_serializer_empty.py15
-rw-r--r--rest_framework/tests/test_serializer_nested.py102
-rw-r--r--rest_framework/tests/test_status.py33
-rw-r--r--rest_framework/urlpatterns.py2
-rw-r--r--rest_framework/utils/encoders.py8
-rw-r--r--rest_framework/views.py4
-rw-r--r--rest_framework/viewsets.py2
32 files changed, 899 insertions, 162 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 2bd2991b..f5483b9d 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,6 +1,20 @@
-__version__ = '2.3.8'
+"""
+______ _____ _____ _____ __ _
+| ___ \ ___/ ___|_ _| / _| | |
+| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| | __
+| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /
+| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
+\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
+"""
-VERSION = __version__ # synonym
+__title__ = 'Django REST framework'
+__version__ = '2.3.10'
+__author__ = 'Tom Christie'
+__license__ = 'BSD 2-Clause'
+__copyright__ = 'Copyright 2011-2013 Tom Christie'
+
+# Version synonym
+VERSION = __version__
# Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1'
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index efd2581f..b4d37ab8 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -65,6 +65,13 @@ try:
except ImportError:
import urlparse
+# UserDict moves in Python 3
+try:
+ from UserDict import UserDict
+ from UserDict import DictMixin
+except ImportError:
+ from collections import UserDict
+ from collections import MutableMapping as DictMixin
# Try to import PIL in either of the two ways it can end up installed.
try:
@@ -76,6 +83,22 @@ except ImportError:
Image = None
+def get_model_name(model_cls):
+ try:
+ return model_cls._meta.model_name
+ except AttributeError:
+ # < 1.6 used module_name instead of model_name
+ return model_cls._meta.module_name
+
+
+def get_concrete_model(model_cls):
+ try:
+ return model_cls._meta.concrete_model
+ except AttributeError:
+ # 1.3 does not include concrete model
+ return model_cls
+
+
# Django 1.5 add support for custom auth user model
if django.VERSION >= (1, 5):
AUTH_USER_MODEL = settings.AUTH_USER_MODEL
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index f340510d..65edd0d6 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -125,6 +125,7 @@ class Field(object):
use_files = False
form_field_class = forms.CharField
type_label = 'field'
+ widget = None
def __init__(self, source=None, label=None, help_text=None):
self.parent = None
@@ -136,9 +137,29 @@ class Field(object):
if label is not None:
self.label = smart_text(label)
+ else:
+ self.label = None
if help_text is not None:
self.help_text = strip_multiple_choice_msg(smart_text(help_text))
+ else:
+ self.help_text = None
+
+ self._errors = []
+ self._value = None
+ self._name = None
+
+ @property
+ def errors(self):
+ return self._errors
+
+ def widget_html(self):
+ if not self.widget:
+ return ''
+ return self.widget.render(self._name, self._value)
+
+ def label_tag(self):
+ return '<label for="%s">%s:</label>' % (self._name, self.label)
def initialize(self, parent, field_name):
"""
@@ -301,6 +322,7 @@ class WritableField(Field):
return
try:
+ data = data or {}
if self.use_files:
files = files or {}
try:
@@ -470,6 +492,7 @@ class ChoiceField(WritableField):
}
def __init__(self, choices=(), *args, **kwargs):
+ self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
if not self.required:
@@ -486,6 +509,11 @@ class ChoiceField(WritableField):
choices = property(_get_choices, _set_choices)
+ def metadata(self):
+ data = super(ChoiceField, self).metadata()
+ data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices]
+ return data
+
def validate(self, value):
"""
Validates that the input is in self.choices.
@@ -510,9 +538,10 @@ class ChoiceField(WritableField):
return False
def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
- return super(ChoiceField, self).from_native(value)
+ value = super(ChoiceField, self).from_native(value)
+ if value == self.empty or value in validators.EMPTY_VALUES:
+ return self.empty
+ return value
class EmailField(CharField):
@@ -751,6 +780,7 @@ class IntegerField(WritableField):
type_name = 'IntegerField'
type_label = 'integer'
form_field_class = forms.IntegerField
+ empty = 0
default_error_messages = {
'invalid': _('Enter a whole number.'),
@@ -782,6 +812,7 @@ class FloatField(WritableField):
type_name = 'FloatField'
type_label = 'float'
form_field_class = forms.FloatField
+ empty = 0
default_error_messages = {
'invalid': _("'%s' value must be a float."),
@@ -802,6 +833,7 @@ class DecimalField(WritableField):
type_name = 'DecimalField'
type_label = 'decimal'
form_field_class = forms.DecimalField
+ empty = Decimal('0')
default_error_messages = {
'invalid': _('Enter a number.'),
@@ -934,7 +966,7 @@ class ImageField(FileField):
return None
from rest_framework.compat import Image
- assert Image is not None, 'PIL must be installed for ImageField support'
+ assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.'
# We need to get a file object for PIL. We might have a path or we might
# have to read the data into memory.
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index b8fe7f77..5c6a187c 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -4,7 +4,7 @@ returned by list views.
"""
from __future__ import unicode_literals
from django.db import models
-from rest_framework.compat import django_filters, six, guardian
+from rest_framework.compat import django_filters, six, guardian, get_model_name
from functools import reduce
import operator
@@ -124,6 +124,7 @@ class OrderingFilter(BaseFilterBackend):
def remove_invalid_fields(self, queryset, ordering):
field_names = [field.name for field in queryset.model._meta.fields]
+ field_names += queryset.query.aggregates.keys()
return [term for term in ordering if term.lstrip('-') in field_names]
def filter_queryset(self, request, queryset, view):
@@ -158,7 +159,7 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend):
model_cls = queryset.model
kwargs = {
'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.module_name
+ 'model_name': get_model_name(model_cls)
}
permission = self.perm_format % kwargs
return guardian.shortcuts.get_objects_for_user(user, permission, queryset)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 5fb37db7..bd33c01a 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -25,13 +25,13 @@ def strict_positive_int(integer_string, cutoff=None):
ret = min(ret, cutoff)
return ret
-def get_object_or_404(queryset, **filter_kwargs):
+def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to raise 404
if the filter_kwargs don't match the required types.
"""
try:
- return _get_object_or_404(queryset, **filter_kwargs)
+ return _get_object_or_404(queryset, *filter_args, **filter_kwargs)
except (TypeError, ValueError):
raise Http404
@@ -54,6 +54,7 @@ class GenericAPIView(views.APIView):
# If you want to use object lookups other than pk, set this attribute.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
+ lookup_url_kwarg = None
# Pagination settings
paginate_by = api_settings.PAGINATE_BY
@@ -147,8 +148,8 @@ class GenericAPIView(views.APIView):
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
try:
- page_number = strict_positive_int(page)
- except ValueError:
+ page_number = paginator.validate_number(page)
+ except InvalidPage:
if page == 'last':
page_number = paginator.num_pages
else:
@@ -174,6 +175,14 @@ class GenericAPIView(views.APIView):
method if you want to apply the configured filtering backend to the
default queryset.
"""
+ for backend in self.get_filter_backends():
+ queryset = backend().filter_queryset(self.request, queryset, self)
+ return queryset
+
+ def get_filter_backends(self):
+ """
+ Returns the list of filter backends that this view requires.
+ """
filter_backends = self.filter_backends or []
if not filter_backends and self.filter_backend:
warnings.warn(
@@ -184,10 +193,8 @@ class GenericAPIView(views.APIView):
DeprecationWarning, stacklevel=2
)
filter_backends = [self.filter_backend]
+ return filter_backends
- for backend in filter_backends:
- queryset = backend().filter_queryset(self.request, queryset, self)
- return queryset
########################
### The following methods provide default implementations
@@ -278,9 +285,11 @@ class GenericAPIView(views.APIView):
pass # Deprecation warning
# Perform the lookup filtering.
+ # Note that `pk` and `slug` are deprecated styles of lookup filtering.
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup = self.kwargs.get(lookup_url_kwarg, None)
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
- lookup = self.kwargs.get(self.lookup_field, None)
if lookup is not None:
filter_kwargs = {self.lookup_field: lookup}
@@ -335,6 +344,18 @@ class GenericAPIView(views.APIView):
"""
pass
+ def pre_delete(self, obj):
+ """
+ Placeholder method for calling before deleting an object.
+ """
+ pass
+
+ def post_delete(self, obj):
+ """
+ Placeholder method for calling after saving an object.
+ """
+ pass
+
def metadata(self, request):
"""
Return a dictionary of metadata about the view.
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 2c85d157..b62a4cc1 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -6,6 +6,7 @@ which allows mixin classes to be composed in interesting ways.
"""
from __future__ import unicode_literals
+from django.core.exceptions import ValidationError
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
@@ -127,7 +128,12 @@ class UpdateModelMixin(object):
files=request.FILES, partial=partial)
if serializer.is_valid():
- self.pre_save(serializer.object)
+ try:
+ self.pre_save(serializer.object)
+ except ValidationError as err:
+ # full_clean on model instance may be called in pre_save, so we
+ # have to handle eventual errors.
+ return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)
self.object = serializer.save(**save_kwargs)
self.post_save(self.object, created=created)
return Response(serializer.data, status=success_status_code)
@@ -158,7 +164,8 @@ class UpdateModelMixin(object):
Set any attributes on the object that are implicit in the request.
"""
# pk and/or slug attributes are implicit in the URL.
- lookup = self.kwargs.get(self.lookup_field, None)
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup = self.kwargs.get(lookup_url_kwarg, None)
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
slug_field = slug and self.slug_field or None
@@ -185,5 +192,7 @@ class DestroyModelMixin(object):
"""
def destroy(self, request, *args, **kwargs):
obj = self.get_object()
+ self.pre_delete(obj)
obj.delete()
+ self.post_delete(obj)
return Response(status=status.HTTP_204_NO_CONTENT)
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 98fc0341..f1b3e38d 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -83,7 +83,7 @@ class YAMLParser(BaseParser):
data = stream.read().decode(encoding)
return yaml.safe_load(data)
except (ValueError, yaml.parser.ParserError) as exc:
- raise ParseError('YAML parse error - %s' % six.u(exc))
+ raise ParseError('YAML parse error - %s' % six.text_type(exc))
class FormParser(BaseParser):
@@ -153,7 +153,7 @@ class XMLParser(BaseParser):
try:
tree = etree.parse(stream, parser=parser, forbid_dtd=True)
except (etree.ParseError, ValueError) as exc:
- raise ParseError('XML parse error - %s' % six.u(exc))
+ raise ParseError('XML parse error - %s' % six.text_type(exc))
data = self._xml_convert(tree.getroot())
return data
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 14bec42c..d93dba19 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -3,7 +3,8 @@ Provides a set of pluggable permission policies.
"""
from __future__ import unicode_literals
from django.http import Http404
-from rest_framework.compat import oauth2_provider_scope, oauth2_constants
+from rest_framework.compat import (get_model_name, oauth2_provider_scope,
+ oauth2_constants)
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
@@ -106,7 +107,7 @@ class DjangoModelPermissions(BasePermission):
"""
kwargs = {
'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.module_name
+ 'model_name': get_model_name(model_cls)
}
return [perm % kwargs for perm in self.perms_map[method]]
@@ -167,7 +168,7 @@ class DjangoObjectPermissions(DjangoModelPermissions):
def get_required_object_permissions(self, method, model_cls):
kwargs = {
'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.module_name
+ 'model_name': get_model_name(model_cls)
}
return [perm % kwargs for perm in self.perms_map[method]]
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 2ce51e97..2fdd3337 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -20,6 +20,7 @@ from rest_framework.compat import StringIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
from rest_framework.compat import yaml
+from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings
from rest_framework.request import is_form_media_type, override_method
from rest_framework.utils import encoders
@@ -272,7 +273,9 @@ class TemplateHTMLRenderer(BaseRenderer):
return [self.template_name]
elif hasattr(view, 'get_template_names'):
return view.get_template_names()
- raise ImproperlyConfigured('Returned a template response with no template_name')
+ elif hasattr(view, 'template_name'):
+ return [view.template_name]
+ raise ImproperlyConfigured('Returned a template response with no `template_name` attribute set on either the view or response')
def get_exception_template(self, response):
template_names = [name % {'status_code': response.status_code}
@@ -334,71 +337,15 @@ class HTMLFormRenderer(BaseRenderer):
template = 'rest_framework/form.html'
charset = 'utf-8'
- def data_to_form_fields(self, data):
- fields = {}
- for key, val in data.fields.items():
- if getattr(val, 'read_only', True):
- # Don't include read-only fields.
- continue
-
- if getattr(val, 'fields', None):
- # Nested data not supported by HTML forms.
- continue
-
- kwargs = {}
- kwargs['required'] = val.required
-
- #if getattr(v, 'queryset', None):
- # kwargs['queryset'] = v.queryset
-
- if getattr(val, 'choices', None) is not None:
- kwargs['choices'] = val.choices
-
- if getattr(val, 'regex', None) is not None:
- kwargs['regex'] = val.regex
-
- if getattr(val, 'widget', None):
- widget = copy.deepcopy(val.widget)
- kwargs['widget'] = widget
-
- if getattr(val, 'default', None) is not None:
- kwargs['initial'] = val.default
-
- if getattr(val, 'label', None) is not None:
- kwargs['label'] = val.label
-
- if getattr(val, 'help_text', None) is not None:
- kwargs['help_text'] = val.help_text
-
- fields[key] = val.form_field_class(**kwargs)
-
- return fields
-
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
Render serializer data and return an HTML form, as a string.
"""
- # The HTMLFormRenderer currently uses something of a hack to render
- # the content, by translating each of the serializer fields into
- # an html form field, creating a dynamic form using those fields,
- # and then rendering that form.
-
- # This isn't strictly neccessary, as we could render the serilizer
- # fields to HTML directly. The implementation is historical and will
- # likely change at some point.
-
- self.renderer_context = renderer_context or {}
+ renderer_context = renderer_context or {}
request = renderer_context['request']
- # Creating an on the fly form see:
- # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
- fields = self.data_to_form_fields(data)
- DynamicForm = type(str('DynamicForm'), (forms.Form,), fields)
- data = None if data.empty else data
-
template = loader.get_template(self.template)
- context = RequestContext(request, {'form': DynamicForm(data)})
-
+ context = RequestContext(request, {'form': data})
return template.render(context)
@@ -419,8 +366,13 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
renderers = [renderer for renderer in view.renderer_classes
if not issubclass(renderer, BrowsableAPIRenderer)]
+ non_template_renderers = [renderer for renderer in renderers
+ if not hasattr(renderer, 'get_template_names')]
+
if not renderers:
return None
+ elif non_template_renderers:
+ return non_template_renderers[0]()
return renderers[0]()
def get_content(self, renderer, data,
@@ -468,6 +420,17 @@ class BrowsableAPIRenderer(BaseRenderer):
In the absence of the View having an associated form then return None.
"""
+ if request.method == method:
+ try:
+ data = request.DATA
+ files = request.FILES
+ except ParseError:
+ data = None
+ files = None
+ else:
+ data = None
+ files = None
+
with override_method(view, request, method) as request:
obj = getattr(view, 'object', None)
if not self.show_form_for_method(view, method, request, obj):
@@ -480,9 +443,10 @@ class BrowsableAPIRenderer(BaseRenderer):
or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)):
return
- serializer = view.get_serializer(instance=obj)
-
+ serializer = view.get_serializer(instance=obj, data=data, files=files)
+ serializer.is_valid()
data = serializer.data
+
form_renderer = self.form_renderer_class()
return form_renderer.render(data, self.accepted_media_type, self.renderer_context)
@@ -574,6 +538,7 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer = self.get_default_renderer(view)
+ raw_data_post_form = self.get_raw_data_form(view, 'POST', request)
raw_data_put_form = self.get_raw_data_form(view, 'PUT', request)
raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
@@ -592,12 +557,11 @@ class BrowsableAPIRenderer(BaseRenderer):
'put_form': self.get_rendered_html_form(view, 'PUT', request),
'post_form': self.get_rendered_html_form(view, 'POST', request),
- 'patch_form': self.get_rendered_html_form(view, 'PATCH', request),
'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
'raw_data_put_form': raw_data_put_form,
- 'raw_data_post_form': self.get_raw_data_form(view, 'POST', request),
+ 'raw_data_post_form': raw_data_post_form,
'raw_data_patch_form': raw_data_patch_form,
'raw_data_put_or_patch_form': raw_data_put_or_patch_form,
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 977d4d96..fcea2508 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -334,7 +334,7 @@ class Request(object):
self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data):
self._content_type = self._data[self._CONTENTTYPE_PARAM]
- self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(HTTP_HEADER_ENCODING))
+ self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty)
def _parse(self):
@@ -356,7 +356,16 @@ class Request(object):
if not parser:
raise exceptions.UnsupportedMediaType(media_type)
- parsed = parser.parse(stream, media_type, self.parser_context)
+ try:
+ parsed = parser.parse(stream, media_type, self.parser_context)
+ except:
+ # If we get an exception during parsing, fill in empty data and
+ # re-raise. Ensures we don't simply repeat the error when
+ # attempting to render the browsable renderer response, or when
+ # logging the request or similar.
+ self._data = QueryDict('', self._request._encoding)
+ self._files = MultiValueDict()
+ raise
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 5877c8a3..1dc6abcf 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -61,6 +61,10 @@ class Response(SimpleTemplateResponse):
assert charset, 'renderer returned unicode, and did not specify ' \
'a charset value.'
return bytes(ret.encode(charset))
+
+ if not ret:
+ del self['Content-Type']
+
return ret
@property
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 9e3881a2..9c27717f 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -6,8 +6,8 @@ form encoded input.
Serialization in REST framework is a two-phase process:
1. Serializers marshal between complex types like model instances, and
-python primatives.
-2. The process of marshalling between python primatives and request and
+python primitives.
+2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
from __future__ import unicode_literals
@@ -31,9 +31,17 @@ from rest_framework.relations import *
from rest_framework.fields import *
+def pretty_name(name):
+ """Converts 'first_name' to 'First name'"""
+ if not name:
+ return ''
+ return name.replace('_', ' ').capitalize()
+
+
class RelationsList(list):
_deleted = []
+
class NestedValidationError(ValidationError):
"""
The default ValidationError behavior is to stringify each item in the list
@@ -48,9 +56,13 @@ class NestedValidationError(ValidationError):
def __init__(self, message):
if isinstance(message, dict):
- self.messages = [message]
+ self._messages = [message]
else:
- self.messages = message
+ self._messages = message
+
+ @property
+ def messages(self):
+ return self._messages
class DictWithMetadata(dict):
@@ -254,10 +266,13 @@ class BaseSerializer(WritableField):
for field_name, field in self.fields.items():
if field_name in self._errors:
continue
+
+ source = field.source or field_name
+ if self.partial and source not in attrs:
+ continue
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
- source = field.source or field_name
attrs = validate_method(attrs, source)
except ValidationError as err:
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
@@ -300,14 +315,19 @@ class BaseSerializer(WritableField):
"""
ret = self._dict_class()
ret.fields = self._dict_class()
- ret.empty = obj is None
for field_name, field in self.fields.items():
+ if field.read_only and obj is None:
+ continue
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
+ method = getattr(self, 'transform_%s' % field_name, None)
+ if callable(method):
+ value = method(obj, value)
ret[key] = value
- ret.fields[key] = field
+ ret.fields[key] = self.augment_field(field, field_name, key, value)
+
return ret
def from_native(self, data, files):
@@ -315,6 +335,7 @@ class BaseSerializer(WritableField):
Deserialize primitives -> objects.
"""
self._errors = {}
+
if data is not None or files is not None:
attrs = self.restore_fields(data, files)
if attrs is not None:
@@ -325,6 +346,15 @@ class BaseSerializer(WritableField):
if not self._errors:
return self.restore_object(attrs, instance=getattr(self, 'object', None))
+ def augment_field(self, field, field_name, key, value):
+ # This horrible stuff is to manage serializers rendering to HTML
+ field._errors = self._errors.get(key) if self._errors else None
+ field._name = field_name
+ field._value = self.init_data.get(key) if self._errors and self.init_data else value
+ if not field.label:
+ field.label = pretty_name(key)
+ return field
+
def field_to_native(self, obj, field_name):
"""
Override default so that the serializer can be used as a nested field
@@ -375,8 +405,14 @@ class BaseSerializer(WritableField):
return
# Set the serializer object if it exists
- obj = getattr(self.parent.object, field_name) if self.parent.object else None
- obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
+ obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
+
+ # If we have a model manager or similar object then we need
+ # to iterate through each instance.
+ if (self.many and
+ not hasattr(obj, '__iter__') and
+ is_simple_callable(getattr(obj, 'all', None))):
+ obj = obj.all()
if self.source == '*':
if value:
@@ -503,6 +539,9 @@ class BaseSerializer(WritableField):
"""
Save the deserialized object and return it.
"""
+ # Clear cached _data, which may be invalidated by `save()`
+ self._data = None
+
if isinstance(self.object, list):
[self.save_object(item, **kwargs) for item in self.object]
@@ -751,6 +790,8 @@ class ModelSerializer(Serializer):
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
kwargs['choices'] = model_field.flatchoices
+ if model_field.null:
+ kwargs['empty'] = None
return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
@@ -822,13 +863,13 @@ class ModelSerializer(Serializer):
# Reverse fk or one-to-one relations
for (obj, model) in meta.get_all_related_objects_with_model():
- field_name = obj.field.related_query_name()
+ field_name = obj.get_accessor_name()
if field_name in attrs:
related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations
for (obj, model) in meta.get_all_related_m2m_objects_with_model():
- field_name = obj.field.related_query_name()
+ field_name = obj.get_accessor_name()
if field_name in attrs:
m2m_data[field_name] = attrs.pop(field_name)
@@ -846,7 +887,10 @@ class ModelSerializer(Serializer):
# Update an existing instance...
if instance is not None:
for key, val in attrs.items():
- setattr(instance, key, val)
+ try:
+ setattr(instance, key, val)
+ except ValueError:
+ self._errors[key] = self.error_messages['required']
# ...or create a new instance
else:
@@ -872,7 +916,7 @@ class ModelSerializer(Serializer):
def save_object(self, obj, **kwargs):
"""
- Save the deserialized object and return it.
+ Save the deserialized object.
"""
if getattr(obj, '_nested_forward_relations', None):
# Nested relationships need to be saved before we can save the
@@ -890,11 +934,16 @@ class ModelSerializer(Serializer):
del(obj._m2m_data)
if getattr(obj, '_related_data', None):
+ related_fields = dict([
+ (field.get_accessor_name(), field)
+ for field, model
+ in obj._meta.get_all_related_objects_with_model()
+ ])
for accessor_name, related in obj._related_data.items():
if isinstance(related, RelationsList):
# Nested reverse fk relationship
for related_item in related:
- fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ fk_field = related_fields[accessor_name].field.name
setattr(related_item, fk_field, obj)
self.save_object(related_item)
diff --git a/rest_framework/status.py b/rest_framework/status.py
index b9f249f9..76435371 100644
--- a/rest_framework/status.py
+++ b/rest_framework/status.py
@@ -6,6 +6,23 @@ And RFC 6585 - http://tools.ietf.org/html/rfc6585
"""
from __future__ import unicode_literals
+
+def is_informational(code):
+ return code >= 100 and code <= 199
+
+def is_success(code):
+ return code >= 200 and code <= 299
+
+def is_redirect(code):
+ return code >= 300 and code <= 399
+
+def is_client_error(code):
+ return code >= 400 and code <= 499
+
+def is_server_error(code):
+ return code >= 500 and code <= 599
+
+
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
HTTP_200_OK = 200
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 47377d51..42ede968 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -111,7 +111,9 @@
<div class="content-main">
<div class="page-header"><h1>{{ name }}</h1></div>
+ {% block description %}
{{ description }}
+ {% endblock %}
<div class="request-info" style="clear: both" >
<pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>
</div>
@@ -152,7 +154,7 @@
{% with form=raw_data_post_form %}
<form action="{{ request.get_full_path }}" method="POST" class="form-horizontal">
<fieldset>
- {% include "rest_framework/form.html" %}
+ {% include "rest_framework/raw_data_form.html" %}
<div class="form-actions">
<button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>
</div>
@@ -189,7 +191,7 @@
{% with form=raw_data_put_or_patch_form %}
<form action="{{ request.get_full_path }}" method="POST" class="form-horizontal">
<fieldset>
- {% include "rest_framework/form.html" %}
+ {% include "rest_framework/raw_data_form.html" %}
<div class="form-actions">
{% if raw_data_put_form %}
<button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>
@@ -220,9 +222,6 @@
</div><!-- ./wrapper -->
{% block footer %}
- <!--<div id="footer">
- <a class="powered-by" href='http://django-rest-framework.org'>Django REST framework</a>
- </div>-->
{% endblock %}
{% block script %}
diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html
index b27f652e..b1e148df 100644
--- a/rest_framework/templates/rest_framework/form.html
+++ b/rest_framework/templates/rest_framework/form.html
@@ -1,13 +1,15 @@
{% load rest_framework %}
{% csrf_token %}
{{ form.non_field_errors }}
-{% for field in form %}
- <div class="control-group"> <!--{% if field.errors %}error{% endif %}-->
+{% for field in form.fields.values %}
+ {% if not field.read_only %}
+ <div class="control-group {% if field.errors %}error{% endif %}">
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
- {{ field }}
- <span class="help-block">{{ field.help_text }}</span>
- <!--{{ field.errors|add_class:"help-block" }}-->
+ {{ field.widget_html }}
+ {% if field.help_text %}<span class="help-block">{{ field.help_text }}</span>{% endif %}
+ {% for error in field.errors %}<span class="help-block">{{ error }}</span>{% endfor %}
</div>
</div>
+ {% endif %}
{% endfor %}
diff --git a/rest_framework/templates/rest_framework/raw_data_form.html b/rest_framework/templates/rest_framework/raw_data_form.html
new file mode 100644
index 00000000..075279f7
--- /dev/null
+++ b/rest_framework/templates/rest_framework/raw_data_form.html
@@ -0,0 +1,12 @@
+{% load rest_framework %}
+{% csrf_token %}
+{{ form.non_field_errors }}
+{% for field in form %}
+ <div class="control-group">
+ {{ field.label_tag|add_class:"control-label" }}
+ <div class="controls">
+ {{ field }}
+ <span class="help-block">{{ field.help_text }}</span>
+ </div>
+ </div>
+{% endfor %}
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index 34fbab9c..5c96bce9 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -42,6 +42,31 @@ class TimeFieldModelSerializer(serializers.ModelSerializer):
model = TimeFieldModel
+SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+]
+
+
+class ChoiceFieldModel(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255)
+
+
+class ChoiceFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModel
+
+
+class ChoiceFieldModelWithNull(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255)
+
+
+class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModelWithNull
+
+
class BasicFieldTests(TestCase):
def test_auto_now_fields_read_only(self):
"""
@@ -667,34 +692,71 @@ class ChoiceFieldTests(TestCase):
"""
Tests for the ChoiceField options generator
"""
-
- SAMPLE_CHOICES = [
- ('red', 'Red'),
- ('green', 'Green'),
- ('blue', 'Blue'),
- ]
-
def test_choices_required(self):
"""
Make sure proper choices are rendered if field is required
"""
- f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES)
- self.assertEqual(f.choices, self.SAMPLE_CHOICES)
+ f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, SAMPLE_CHOICES)
def test_choices_not_required(self):
"""
Make sure proper choices (plus blank) are rendered if the field isn't required
"""
- f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
- self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
+
+ def test_invalid_choice_model(self):
+ s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
+ self.assertEqual(s.data['choice'], '')
+
+ def test_empty_choice_model(self):
+ """
+ Test that the 'empty' value is correctly passed and used depending on
+ the 'null' property on the model field.
+ """
+ s = ChoiceFieldModelSerializer(data={'choice': ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], '')
+
+ s = ChoiceFieldModelWithNullSerializer(data={'choice': ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], None)
def test_from_native_empty(self):
"""
- Make sure from_native() returns None on empty param.
+ Make sure from_native() returns an empty string on empty param by default.
"""
- f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES)
- result = f.from_native('')
- self.assertEqual(result, None)
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.from_native(''), '')
+ self.assertEqual(f.from_native(None), '')
+
+ def test_from_native_empty_override(self):
+ """
+ Make sure you can override from_native() behavior regarding empty values.
+ """
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None)
+ self.assertEqual(f.from_native(''), None)
+ self.assertEqual(f.from_native(None), None)
+
+ def test_metadata_choices(self):
+ """
+ Make sure proper choices are included in the field's metadata.
+ """
+ choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES]
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.metadata()['choices'], choices)
+
+ def test_metadata_choices_not_required(self):
+ """
+ Make sure proper choices are included in the field's metadata.
+ """
+ choices = [{'value': v, 'display_name': n}
+ for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES]
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.metadata()['choices'], choices)
class EmailFieldTests(TestCase):
diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py
index c13c38b8..78f4cf42 100644
--- a/rest_framework/tests/test_files.py
+++ b/rest_framework/tests/test_files.py
@@ -80,3 +80,16 @@ class FileSerializerTests(TestCase):
serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'file': [errmsg]})
+
+ def test_validation_with_no_data(self):
+ """
+ Validation should still function when no data dictionary is provided.
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(files={'file': file})
+ self.assertFalse(serializer.is_valid())
diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py
index 9697c5ee..8a03a077 100644
--- a/rest_framework/tests/test_filters.py
+++ b/rest_framework/tests/test_filters.py
@@ -364,6 +364,12 @@ class OrdringFilterModel(models.Model):
text = models.CharField(max_length=100)
+class OrderingFilterRelatedModel(models.Model):
+ related_object = models.ForeignKey(OrdringFilterModel,
+ related_name="relateds")
+
+
+
class OrderingFilterTests(TestCase):
def setUp(self):
# Sequence of title/text is:
@@ -473,3 +479,36 @@ class OrderingFilterTests(TestCase):
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
)
+
+ def test_ordering_by_aggregate_field(self):
+ # create some related models to aggregate order by
+ num_objs = [2, 5, 3]
+ for obj, num_relateds in zip(OrdringFilterModel.objects.all(),
+ num_objs):
+ for _ in range(num_relateds):
+ new_related = OrderingFilterRelatedModel(
+ related_object=obj
+ )
+ new_related.save()
+
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ queryset = OrdringFilterModel.objects.all().annotate(
+ models.Count("relateds"))
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=relateds__count')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ ]
+ )
+
+
+
diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py
index 79cd99ac..996bd5b0 100644
--- a/rest_framework/tests/test_generics.py
+++ b/rest_framework/tests/test_generics.py
@@ -23,6 +23,10 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView):
"""
model = BasicModel
+ def get_queryset(self):
+ queryset = super(InstanceView, self).get_queryset()
+ return queryset.exclude(text='filtered out')
+
class SlugSerializer(serializers.ModelSerializer):
slug = serializers.Field() # read only
@@ -160,10 +164,10 @@ class TestInstanceView(TestCase):
"""
Create 3 BasicModel intances.
"""
- items = ['foo', 'bar', 'baz']
+ items = ['foo', 'bar', 'baz', 'filtered out']
for item in items:
BasicModel(text=item).save()
- self.objects = BasicModel.objects
+ self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
@@ -352,6 +356,17 @@ class TestInstanceView(TestCase):
updated = self.objects.get(id=1)
self.assertEqual(updated.text, 'foobar')
+ def test_put_to_filtered_out_instance(self):
+ """
+ PUT requests to an URL of instance which is filtered out should not be
+ able to create new objects.
+ """
+ data = {'text': 'foo'}
+ filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
+ request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
+ response = self.view(request, pk=filtered_out_pk).render()
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
def test_put_as_create_on_id_based_url(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView should create an object
@@ -508,6 +523,25 @@ class ExclusiveFilterBackend(object):
return queryset.filter(text='other')
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ model = TwoFieldModel
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ return DynamicSerializer
+ return super(DynamicSerializerView, self).get_serializer_class()
+
+
class TestFilterBackendAppliedToViews(TestCase):
def setUp(self):
@@ -564,28 +598,6 @@ class TestFilterBackendAppliedToViews(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
-
-class TwoFieldModel(models.Model):
- field_a = models.CharField(max_length=100)
- field_b = models.CharField(max_length=100)
-
-
-class DynamicSerializerView(generics.ListCreateAPIView):
- model = TwoFieldModel
- renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
-
- def get_serializer_class(self):
- if self.request.method == 'POST':
- class DynamicSerializer(serializers.ModelSerializer):
- class Meta:
- model = TwoFieldModel
- fields = ('field_b',)
- return DynamicSerializer
- return super(DynamicSerializerView, self).get_serializer_class()
-
-
-class TestFilterBackendAppliedToViews(TestCase):
-
def test_dynamic_serializer_form_in_browsable_api(self):
"""
GET requests to ListCreateAPIView should return filtered list.
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
index d6bc7895..cadb515f 100644
--- a/rest_framework/tests/test_pagination.py
+++ b/rest_framework/tests/test_pagination.py
@@ -430,3 +430,88 @@ class TestCustomPaginationSerializer(TestCase):
'objects': ['john', 'paul']
}
self.assertEqual(serializer.data, expected)
+
+
+class NonIntegerPage(object):
+
+ def __init__(self, paginator, object_list, prev_token, token, next_token):
+ self.paginator = paginator
+ self.object_list = object_list
+ self.prev_token = prev_token
+ self.token = token
+ self.next_token = next_token
+
+ def has_next(self):
+ return not not self.next_token
+
+ def next_page_number(self):
+ return self.next_token
+
+ def has_previous(self):
+ return not not self.prev_token
+
+ def previous_page_number(self):
+ return self.prev_token
+
+
+class NonIntegerPaginator(object):
+
+ def __init__(self, object_list, per_page):
+ self.object_list = object_list
+ self.per_page = per_page
+
+ def count(self):
+ # pretend like we don't know how many pages we have
+ return None
+
+ def page(self, token=None):
+ if token:
+ try:
+ first = self.object_list.index(token)
+ except ValueError:
+ first = 0
+ else:
+ first = 0
+ n = len(self.object_list)
+ last = min(first + self.per_page, n)
+ prev_token = self.object_list[last - (2 * self.per_page)] if first else None
+ next_token = self.object_list[last] if last < n else None
+ return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
+
+
+class TestNonIntegerPagination(TestCase):
+
+
+ def test_custom_pagination_serializer(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = NonIntegerPaginator(objects, 2)
+
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=paginator.page(),
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
+ 'prev': None
+ },
+ 'total_results': None,
+ 'objects': objects[:2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=paginator.page('george'),
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': None,
+ 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
+ },
+ 'total_results': None,
+ 'objects': objects[2:]
+ }
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py
index d08124f4..6e3a6303 100644
--- a/rest_framework/tests/test_permissions.py
+++ b/rest_framework/tests/test_permissions.py
@@ -4,7 +4,7 @@ from django.db import models
from django.test import TestCase
from django.utils import unittest
from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
-from rest_framework.compat import guardian
+from rest_framework.compat import guardian, get_model_name
from rest_framework.filters import DjangoObjectPermissionsFilter
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
@@ -202,7 +202,7 @@ class ObjectPermissionsIntegrationTests(TestCase):
# give everyone model level permissions, as we are not testing those
everyone = Group.objects.create(name='everyone')
- model_name = BasicPermModel._meta.module_name
+ model_name = get_model_name(BasicPermModel)
app_label = BasicPermModel._meta.app_label
f = '{0}_{1}'.format
perms = {
diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py
index 9d1dd77e..9cb68233 100644
--- a/rest_framework/tests/test_renderers.py
+++ b/rest_framework/tests/test_renderers.py
@@ -16,7 +16,9 @@ from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
+from collections import MutableMapping
import datetime
+import json
import pickle
import re
@@ -65,11 +67,23 @@ class MockView(APIView):
class MockGETView(APIView):
-
def get(self, request, **kwargs):
return Response({'foo': ['bar', 'baz']})
+
+class MockPOSTView(APIView):
+ def post(self, request, **kwargs):
+ return Response({'foo': request.DATA})
+
+
+class EmptyGETView(APIView):
+ renderer_classes = (JSONRenderer,)
+
+ def get(self, request, **kwargs):
+ return Response(status=status.HTTP_204_NO_CONTENT)
+
+
class HTMLView(APIView):
renderer_classes = (BrowsableAPIRenderer, )
@@ -89,8 +103,10 @@ urlpatterns = patterns('',
url(r'^cache$', MockGETView.as_view()),
url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
+ url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
+ url(r'^empty$', EmptyGETView.as_view()),
url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
)
@@ -220,6 +236,22 @@ class RendererEndToEndTests(TestCase):
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
+ def test_parse_error_renderers_browsable_api(self):
+ """Invalid data should still render the browsable API correctly."""
+ resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
+
+ def test_204_no_content_responses_have_no_content_type_set(self):
+ """
+ Regression test for #1196
+
+ https://github.com/tomchristie/django-rest-framework/issues/1196
+ """
+ resp = self.client.get('/empty')
+ self.assertEqual(resp.get('Content-Type', None), None)
+ self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
+
_flat_repr = '{"foo": ["bar", "baz"]}'
_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
@@ -245,6 +277,44 @@ class JSONRendererTests(TestCase):
ret = JSONRenderer().render(_('test'))
self.assertEqual(ret, b'"test"')
+ def test_render_dict_abc_obj(self):
+ class Dict(MutableMapping):
+ def __init__(self):
+ self._dict = dict()
+ def __getitem__(self, key):
+ return self._dict.__getitem__(key)
+ def __setitem__(self, key, value):
+ return self._dict.__setitem__(key, value)
+ def __delitem__(self, key):
+ return self._dict.__delitem__(key)
+ def __iter__(self):
+ return self._dict.__iter__()
+ def __len__(self):
+ return self._dict.__len__()
+ def keys(self):
+ return self._dict.keys()
+
+ x = Dict()
+ x['key'] = 'string value'
+ x[2] = 3
+ ret = JSONRenderer().render(x)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, {'key': 'string value', '2': 3})
+
+ def test_render_obj_with_getitem(self):
+ class DictLike(object):
+ def __init__(self):
+ self._dict = {}
+ def set(self, value):
+ self._dict = dict(value)
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ x = DictLike()
+ x.set({'a': 1, 'b': 'string'})
+ with self.assertRaises(TypeError):
+ JSONRenderer().render(x)
+
def test_without_content_type_args(self):
"""
Test basic JSON rendering.
@@ -329,7 +399,7 @@ if yaml:
class YAMLRendererTests(TestCase):
"""
- Tests specific to the JSON Renderer
+ Tests specific to the YAML Renderer
"""
def test_render(self):
@@ -355,6 +425,17 @@ if yaml:
data = parser.parse(StringIO(content))
self.assertEqual(obj, data)
+ def test_render_decimal(self):
+ """
+ Test YAML decimal rendering.
+ """
+ renderer = YAMLRenderer()
+ content = renderer.render({'field': Decimal('111.2')}, 'application/yaml')
+ self.assertYAMLContains(content, "field: '111.2'")
+
+ def assertYAMLContains(self, content, string):
+ self.assertTrue(string in content, '%r not in %r' % (string, content))
+
class XMLRendererTestCase(TestCase):
"""
diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py
index d6363425..f07c31a3 100644
--- a/rest_framework/tests/test_request.py
+++ b/rest_framework/tests/test_request.py
@@ -6,6 +6,7 @@ from django.conf.urls import patterns
from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware
+from django.core.handlers.wsgi import WSGIRequest
from django.test import TestCase
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
@@ -15,12 +16,13 @@ from rest_framework.parsers import (
MultiPartParser,
JSONParser
)
-from rest_framework.request import Request
+from rest_framework.request import Request, Empty
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
from rest_framework.compat import six
+from io import BytesIO
import json
@@ -146,6 +148,34 @@ class TestContentParsing(TestCase):
request.parsers = (JSONParser(), )
self.assertEqual(request.DATA, json_data)
+ def test_form_POST_unicode(self):
+ """
+ JSON POST via default web interface with unicode data
+ """
+ # Note: environ and other variables here have simplified content compared to real Request
+ CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D'
+ environ = {
+ 'REQUEST_METHOD': 'POST',
+ 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
+ 'CONTENT_LENGTH': len(CONTENT),
+ 'wsgi.input': BytesIO(CONTENT),
+ }
+ wsgi_request = WSGIRequest(environ=environ)
+ wsgi_request._load_post_and_files()
+ parsers = (JSONParser(), FormParser(), MultiPartParser())
+ parser_context = {
+ 'encoding': 'utf-8',
+ 'kwargs': {},
+ 'args': (),
+ }
+ request = Request(wsgi_request, parsers=parsers, parser_context=parser_context)
+ method = request.method
+ self.assertEqual(method, 'POST')
+ self.assertEqual(request._content_type, 'application/json')
+ self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}')
+ self.assertEqual(request._data, Empty)
+ self.assertEqual(request._files, Empty)
+
# def test_accessing_post_after_data_form(self):
# """
# Ensures request.POST can be accessed after request.DATA in
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 739bb70a..e80276e9 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import models
from django.db.models.fields import BLANK_CHOICE_DASH
@@ -136,6 +137,7 @@ class BasicTests(TestCase):
'Happy new year!',
datetime.datetime(2012, 1, 1)
)
+ self.actionitem = ActionItem(title='Some to do item',)
self.data = {
'email': 'tom@example.com',
'content': 'Happy new year!',
@@ -157,8 +159,7 @@ class BasicTests(TestCase):
expected = {
'email': '',
'content': '',
- 'created': None,
- 'sub_comment': ''
+ 'created': None
}
self.assertEqual(serializer.data, expected)
@@ -264,6 +265,20 @@ class BasicTests(TestCase):
"""
self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+ def test_serializer_data_is_cleared_on_save(self):
+ """
+ Check _data attribute is cleared on `save()`
+
+ Regression test for #1116
+ — id field is not populated if `data` is accessed prior to `save()`
+ """
+ serializer = ActionItemSerializer(self.actionitem)
+ self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.')
+ serializer.save()
+ self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.')
+
+
+
class DictStyleSerializer(serializers.Serializer):
"""
@@ -496,6 +511,33 @@ class CustomValidationTests(TestCase):
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']})
+ def test_partial_update(self):
+ """
+ Make sure that validate_email isn't called when partial=True and email
+ isn't found in data.
+ """
+ initial_data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=initial_data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.object
+
+ new_content = 'An *updated* test comment'
+ partial_data = {
+ 'content': new_content
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(instance=instance,
+ data=partial_data,
+ partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.object
+ self.assertEqual(instance.content, new_content)
+
class PositiveIntegerAsChoiceTests(TestCase):
def test_positive_integer_in_json_is_correctly_parsed(self):
@@ -516,6 +558,29 @@ class ModelValidationTests(TestCase):
self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
+ def test_foreign_key_is_null_with_partial(self):
+ """
+ Test ModelSerializer validation with partial=True
+
+ Specifically test that a null foreign key does not pass validation
+ """
+ album = Album(title='test')
+ album.save()
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+
+ photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
+ self.assertTrue(photo_serializer.is_valid())
+ photo = photo_serializer.save()
+
+ # Updating only the album (foreign key)
+ photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True)
+ self.assertFalse(photo_serializer.is_valid())
+ self.assertTrue('album' in photo_serializer.errors)
+ self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required'])
+
def test_foreign_key_with_partial(self):
"""
Test ModelSerializer validation with partial=True
@@ -1643,3 +1708,38 @@ class SerializerSupportsManyRelationships(TestCase):
serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+
+class TransformMethodsSerializer(serializers.Serializer):
+ a = serializers.CharField()
+ b_renamed = serializers.CharField(source='b')
+
+ def transform_a(self, obj, value):
+ return value.lower()
+
+ def transform_b_renamed(self, obj, value):
+ if value is not None:
+ return 'and ' + value
+
+
+class TestSerializerTransformMethods(TestCase):
+ def setUp(self):
+ self.s = TransformMethodsSerializer()
+
+ def test_transform_methods(self):
+ self.assertEqual(
+ self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}),
+ {
+ 'a': 'green eggs',
+ 'b_renamed': 'and HAM',
+ }
+ )
+
+ def test_missing_fields(self):
+ self.assertEqual(
+ self.s.to_native({'a': 'GREEN EGGS'}),
+ {
+ 'a': 'green eggs',
+ 'b_renamed': None,
+ }
+ )
diff --git a/rest_framework/tests/test_serializer_empty.py b/rest_framework/tests/test_serializer_empty.py
new file mode 100644
index 00000000..30cff361
--- /dev/null
+++ b/rest_framework/tests/test_serializer_empty.py
@@ -0,0 +1,15 @@
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class EmptySerializerTestCase(TestCase):
+ def test_empty_serializer(self):
+ class FooBarSerializer(serializers.Serializer):
+ foo = serializers.IntegerField()
+ bar = serializers.SerializerMethodField('get_bar')
+
+ def get_bar(self, obj):
+ return 'bar'
+
+ serializer = FooBarSerializer()
+ self.assertEquals(serializer.data, {'foo': 0})
diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py
index 71d0e24b..7114a060 100644
--- a/rest_framework/tests/test_serializer_nested.py
+++ b/rest_framework/tests/test_serializer_nested.py
@@ -6,6 +6,7 @@ Doesn't cover model serializers.
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import serializers
+from . import models
class WritableNestedSerializerBasicTests(TestCase):
@@ -244,3 +245,104 @@ class WritableNestedSerializerObjectTests(TestCase):
serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected_object)
+
+
+class ForeignKeyNestedSerializerUpdateTests(TestCase):
+ def setUp(self):
+ class Artist(object):
+ def __init__(self, name):
+ self.name = name
+
+ def __eq__(self, other):
+ return self.name == other.name
+
+ class Album(object):
+ def __init__(self, name, artist):
+ self.name, self.artist = name, artist
+
+ def __eq__(self, other):
+ return self.name == other.name and self.artist == other.artist
+
+ class ArtistSerializer(serializers.Serializer):
+ name = serializers.CharField()
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.name = attrs['name']
+ else:
+ instance = Artist(attrs['name'])
+ return instance
+
+ class AlbumSerializer(serializers.Serializer):
+ name = serializers.CharField()
+ by = ArtistSerializer(source='artist')
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.name = attrs['name']
+ instance.artist = attrs['artist']
+ else:
+ instance = Album(attrs['name'], attrs['artist'])
+ return instance
+
+ self.Artist = Artist
+ self.Album = Album
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_create_via_foreign_key_with_source(self):
+ """
+ Check that we can both *create* and *update* into objects across
+ ForeignKeys that have a `source` specified.
+ Regression test for #1170
+ """
+ data = {
+ 'name': 'Discovery',
+ 'by': {'name': 'Daft Punk'},
+ }
+
+ expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery')
+
+ # create
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+
+ # update
+ original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters')
+ serializer = self.AlbumSerializer(instance=original, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+
+
+class NestedModelSerializerUpdateTests(TestCase):
+ def test_second_nested_level(self):
+ john = models.Person.objects.create(name="john")
+
+ post = john.blogpost_set.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = models.BlogPostComment
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
+ class Meta:
+ model = models.BlogPost
+ fields = ('id', 'title', 'comments')
+
+ class PersonSerializer(serializers.ModelSerializer):
+ posts = BlogPostSerializer(many=True, source='blogpost_set')
+ class Meta:
+ model = models.Person
+ fields = ('id', 'name', 'age', 'posts')
+
+ serialize = PersonSerializer(instance=john)
+ deserialize = PersonSerializer(data=serialize.data, instance=john)
+ self.assertTrue(deserialize.is_valid())
+
+ result = deserialize.object
+ result.save()
+ self.assertEqual(result.id, john.id)
+
diff --git a/rest_framework/tests/test_status.py b/rest_framework/tests/test_status.py
new file mode 100644
index 00000000..7b1bdae3
--- /dev/null
+++ b/rest_framework/tests/test_status.py
@@ -0,0 +1,33 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.status import (
+ is_informational, is_success, is_redirect, is_client_error, is_server_error
+)
+
+
+class TestStatus(TestCase):
+ def test_status_categories(self):
+ self.assertFalse(is_informational(99))
+ self.assertTrue(is_informational(100))
+ self.assertTrue(is_informational(199))
+ self.assertFalse(is_informational(200))
+
+ self.assertFalse(is_success(199))
+ self.assertTrue(is_success(200))
+ self.assertTrue(is_success(299))
+ self.assertFalse(is_success(300))
+
+ self.assertFalse(is_redirect(299))
+ self.assertTrue(is_redirect(300))
+ self.assertTrue(is_redirect(399))
+ self.assertFalse(is_redirect(400))
+
+ self.assertFalse(is_client_error(399))
+ self.assertTrue(is_client_error(400))
+ self.assertTrue(is_client_error(499))
+ self.assertFalse(is_client_error(500))
+
+ self.assertFalse(is_server_error(499))
+ self.assertTrue(is_server_error(500))
+ self.assertTrue(is_server_error(599))
+ self.assertFalse(is_server_error(600)) \ No newline at end of file
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index a62530c7..038e9ee3 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -57,6 +57,6 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
allowed_pattern = '(%s)' % '|'.join(allowed)
suffix_pattern = r'\.(?P<%s>%s)$' % (suffix_kwarg, allowed_pattern)
else:
- suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg
+ suffix_pattern = r'\.(?P<%s>[a-z0-9]+)$' % suffix_kwarg
return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 13a85550..229b0b28 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -45,6 +45,11 @@ class JSONEncoder(json.JSONEncoder):
return str(o)
elif hasattr(o, 'tolist'):
return o.tolist()
+ elif hasattr(o, '__getitem__'):
+ try:
+ return dict(o)
+ except:
+ pass
elif hasattr(o, '__iter__'):
return [i for i in o]
return super(JSONEncoder, self).default(o)
@@ -90,6 +95,9 @@ else:
node.flow_style = best_style
return node
+ SafeDumper.add_representer(decimal.Decimal,
+ SafeDumper.represent_decimal)
+
SafeDumper.add_representer(SortedDict,
yaml.representer.SafeRepresenter.represent_dict)
SafeDumper.add_representer(DictWithMetadata,
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 853e6461..e863af6d 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -154,8 +154,8 @@ class APIView(View):
Returns a dict that is passed through to Parser.parse(),
as the `parser_context` keyword argument.
"""
- # Note: Additionally `request` will also be added to the context
- # by the Request object.
+ # Note: Additionally `request` and `encoding` will also be added
+ # to the context by the Request object.
return {
'view': self,
'args': getattr(self, 'args', ()),
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
index d91323f2..7eb29f99 100644
--- a/rest_framework/viewsets.py
+++ b/rest_framework/viewsets.py
@@ -9,7 +9,7 @@ Actions are only bound to methods at the point of instantiating the views.
user_detail = UserViewSet.as_view({'get': 'retrieve'})
Typically, rather than instantiate views from viewsets directly, you'll
-regsiter the viewset with a router and let the URL conf be determined
+register the viewset with a router and let the URL conf be determined
automatically.
router = DefaultRouter()