aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py18
-rw-r--r--rest_framework/authentication.py6
-rw-r--r--rest_framework/authtoken/models.py8
-rw-r--r--rest_framework/compat.py29
-rw-r--r--rest_framework/exceptions.py33
-rw-r--r--rest_framework/fields.py29
-rw-r--r--rest_framework/filters.py28
-rw-r--r--rest_framework/generics.py24
-rw-r--r--rest_framework/mixins.py13
-rw-r--r--rest_framework/parsers.py4
-rw-r--r--rest_framework/permissions.py22
-rw-r--r--rest_framework/relations.py15
-rw-r--r--rest_framework/renderers.py9
-rw-r--r--rest_framework/request.py15
-rw-r--r--rest_framework/response.py4
-rw-r--r--rest_framework/routers.py14
-rw-r--r--rest_framework/runtests/settings.py3
-rw-r--r--rest_framework/serializers.py122
-rw-r--r--rest_framework/settings.py1
-rw-r--r--rest_framework/status.py17
-rw-r--r--rest_framework/templates/rest_framework/base.html5
-rw-r--r--rest_framework/templatetags/rest_framework.py5
-rw-r--r--rest_framework/tests/accounts/__init__.py0
-rw-r--r--rest_framework/tests/accounts/models.py8
-rw-r--r--rest_framework/tests/accounts/serializers.py11
-rw-r--r--rest_framework/tests/models.py1
-rw-r--r--rest_framework/tests/records/__init__.py0
-rw-r--r--rest_framework/tests/records/models.py6
-rw-r--r--rest_framework/tests/test_authentication.py26
-rw-r--r--rest_framework/tests/test_fields.py92
-rw-r--r--rest_framework/tests/test_filters.py141
-rw-r--r--rest_framework/tests/test_genericrelations.py29
-rw-r--r--rest_framework/tests/test_generics.py60
-rw-r--r--rest_framework/tests/test_hyperlinkedserializers.py46
-rw-r--r--rest_framework/tests/test_relations.py20
-rw-r--r--rest_framework/tests/test_renderers.py91
-rw-r--r--rest_framework/tests/test_request.py32
-rw-r--r--rest_framework/tests/test_serializer.py120
-rw-r--r--rest_framework/tests/test_serializer_import.py19
-rw-r--r--rest_framework/tests/test_serializer_nested.py34
-rw-r--r--rest_framework/tests/test_serializers.py28
-rw-r--r--rest_framework/tests/test_status.py33
-rw-r--r--rest_framework/tests/test_templatetags.py19
-rw-r--r--rest_framework/tests/test_validation.py21
-rw-r--r--rest_framework/tests/test_write_only_fields.py42
-rw-r--r--rest_framework/tests/users/__init__.py0
-rw-r--r--rest_framework/tests/users/models.py6
-rw-r--r--rest_framework/tests/users/serializers.py8
-rw-r--r--rest_framework/urlpatterns.py2
-rw-r--r--rest_framework/utils/encoders.py8
-rw-r--r--rest_framework/views.py4
-rw-r--r--rest_framework/viewsets.py2
52 files changed, 1160 insertions, 173 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 2bd2991b..6759680b 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,6 +1,20 @@
-__version__ = '2.3.8'
+"""
+______ _____ _____ _____ __ _
+| ___ \ ___/ ___|_ _| / _| | |
+| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| | __
+| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /
+| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
+\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
+"""
-VERSION = __version__ # synonym
+__title__ = 'Django REST framework'
+__version__ = '2.3.12'
+__author__ = 'Tom Christie'
+__license__ = 'BSD 2-Clause'
+__copyright__ = 'Copyright 2011-2013 Tom Christie'
+
+# Version synonym
+VERSION = __version__
# Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1'
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index cf001a24..e491ce5f 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -9,7 +9,7 @@ from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
-from rest_framework.compat import oauth2_provider, provider_now
+from rest_framework.compat import oauth2_provider, provider_now, check_nonce
from rest_framework.authtoken.models import Token
@@ -281,7 +281,9 @@ class OAuthAuthentication(BaseAuthentication):
"""
Checks nonce of request, and return True if valid.
"""
- return oauth_provider_store.check_nonce(request, oauth_request, oauth_request['oauth_nonce'])
+ oauth_nonce = oauth_request['oauth_nonce']
+ oauth_timestamp = oauth_request['oauth_timestamp']
+ return check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp)
class OAuth2Authentication(BaseAuthentication):
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 7601f5b7..024f62bf 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -1,11 +1,17 @@
import uuid
import hmac
from hashlib import sha1
-from rest_framework.compat import AUTH_USER_MODEL
from django.conf import settings
from django.db import models
+# Prior to Django 1.5, the AUTH_USER_MODEL setting does not exist.
+# Note that we don't perform this code in the compat module due to
+# bug report #1297
+# See: https://github.com/tomchristie/django-rest-framework/issues/1297
+AUTH_USER_MODEL = getattr(settings, 'AUTH_USER_MODEL', 'auth.User')
+
+
class Token(models.Model):
"""
The default authorization token model.
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 581e29fc..b69749fe 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -7,6 +7,7 @@ versions of django/python, and compatibility wrappers around optional packages.
from __future__ import unicode_literals
import django
+import inspect
from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
@@ -69,6 +70,13 @@ try:
except ImportError:
import urlparse
+# UserDict moves in Python 3
+try:
+ from UserDict import UserDict
+ from UserDict import DictMixin
+except ImportError:
+ from collections import UserDict
+ from collections import MutableMapping as DictMixin
# Try to import PIL in either of the two ways it can end up installed.
try:
@@ -96,13 +104,6 @@ def get_concrete_model(model_cls):
return model_cls
-# Django 1.5 add support for custom auth user model
-if django.VERSION >= (1, 5):
- AUTH_USER_MODEL = settings.AUTH_USER_MODEL
-else:
- AUTH_USER_MODEL = 'auth.User'
-
-
if django.VERSION >= (1, 5):
from django.views.generic import View
else:
@@ -529,9 +530,23 @@ except ImportError:
try:
import oauth_provider
from oauth_provider.store import store as oauth_provider_store
+
+ # check_nonce's calling signature in django-oauth-plus changes sometime
+ # between versions 2.0 and 2.2.1
+ def check_nonce(request, oauth_request, oauth_nonce, oauth_timestamp):
+ check_nonce_args = inspect.getargspec(oauth_provider_store.check_nonce).args
+ if 'timestamp' in check_nonce_args:
+ return oauth_provider_store.check_nonce(
+ request, oauth_request, oauth_nonce, oauth_timestamp
+ )
+ return oauth_provider_store.check_nonce(
+ request, oauth_request, oauth_nonce
+ )
+
except (ImportError, ImproperlyConfigured):
oauth_provider = None
oauth_provider_store = None
+ check_nonce = None
# OAuth 2 support is optional
try:
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 425a7214..4276625a 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -6,6 +6,7 @@ In addition Django's built in 403 and 404 exceptions are handled.
"""
from __future__ import unicode_literals
from rest_framework import status
+import math
class APIException(Exception):
@@ -13,40 +14,32 @@ class APIException(Exception):
Base class for REST framework exceptions.
Subclasses should provide `.status_code` and `.detail` properties.
"""
- pass
+ status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+ default_detail = ''
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Malformed request.'
- def __init__(self, detail=None):
- self.detail = detail or self.default_detail
-
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Incorrect authentication credentials.'
- def __init__(self, detail=None):
- self.detail = detail or self.default_detail
-
class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Authentication credentials were not provided.'
- def __init__(self, detail=None):
- self.detail = detail or self.default_detail
-
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = 'You do not have permission to perform this action.'
- def __init__(self, detail=None):
- self.detail = detail or self.default_detail
-
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
@@ -75,14 +68,14 @@ class UnsupportedMediaType(APIException):
class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
- default_detail = "Request was throttled."
+ default_detail = 'Request was throttled.'
extra_detail = "Expected available in %d second%s."
def __init__(self, wait=None, detail=None):
- import math
- self.wait = wait and math.ceil(wait) or None
- if wait is not None:
- format = detail or self.default_detail + self.extra_detail
- self.detail = format % (self.wait, self.wait != 1 and 's' or '')
- else:
+ if wait is None:
self.detail = detail or self.default_detail
+ self.wait = None
+ else:
+ format = (detail or self.default_detail) + self.extra_detail
+ self.detail = format % (wait, wait != 1 and 's' or '')
+ self.wait = math.ceil(wait)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index e23fc001..2f475d6e 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -246,6 +246,7 @@ class WritableField(Field):
"""
Base for read/write fields.
"""
+ write_only = False
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
@@ -255,7 +256,7 @@ class WritableField(Field):
default = None
def __init__(self, source=None, label=None, help_text=None,
- read_only=False, required=None,
+ read_only=False, write_only=False, required=None,
validators=[], error_messages=None, widget=None,
default=None, blank=None):
@@ -269,6 +270,10 @@ class WritableField(Field):
super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
self.read_only = read_only
+ self.write_only = write_only
+
+ assert not (read_only and write_only), "Cannot set read_only=True and write_only=True"
+
if required is None:
self.required = not(read_only)
else:
@@ -318,6 +323,11 @@ class WritableField(Field):
if errors:
raise ValidationError(errors)
+ def field_to_native(self, obj, field_name):
+ if self.write_only:
+ return None
+ return super(WritableField, self).field_to_native(obj, field_name)
+
def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
@@ -428,7 +438,7 @@ class BooleanField(WritableField):
def field_from_native(self, data, files, field_name, into):
# HTML checkboxes do not explicitly represent unchecked as `False`
# we deal with that here...
- if isinstance(data, QueryDict):
+ if isinstance(data, QueryDict) and self.default is None:
self.default = False
return super(BooleanField, self).field_from_native(
@@ -497,6 +507,7 @@ class ChoiceField(WritableField):
}
def __init__(self, choices=(), *args, **kwargs):
+ self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
if not self.required:
@@ -513,6 +524,11 @@ class ChoiceField(WritableField):
choices = property(_get_choices, _set_choices)
+ def metadata(self):
+ data = super(ChoiceField, self).metadata()
+ data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices]
+ return data
+
def validate(self, value):
"""
Validates that the input is in self.choices.
@@ -537,9 +553,10 @@ class ChoiceField(WritableField):
return False
def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
- return super(ChoiceField, self).from_native(value)
+ value = super(ChoiceField, self).from_native(value)
+ if value == self.empty or value in validators.EMPTY_VALUES:
+ return self.empty
+ return value
class EmailField(CharField):
@@ -964,7 +981,7 @@ class ImageField(FileField):
return None
from rest_framework.compat import Image
- assert Image is not None, 'PIL must be installed for ImageField support'
+ assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.'
# We need to get a file object for PIL. We might have a path or we might
# have to read the data into memory.
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index e287a168..de91caed 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -3,6 +3,7 @@ Provides generic filtering backends that can be used to filter the results
returned by list views.
"""
from __future__ import unicode_literals
+from django.core.exceptions import ImproperlyConfigured
from django.db import models
from rest_framework.compat import django_filters, six, guardian, get_model_name
from functools import reduce
@@ -107,6 +108,7 @@ class SearchFilter(BaseFilterBackend):
class OrderingFilter(BaseFilterBackend):
ordering_param = 'ordering' # The URL query parameter used for the ordering.
+ ordering_fields = None
def get_ordering(self, request):
"""
@@ -122,16 +124,34 @@ class OrderingFilter(BaseFilterBackend):
return (ordering,)
return ordering
- def remove_invalid_fields(self, queryset, ordering):
- field_names = [field.name for field in queryset.model._meta.fields]
- return [term for term in ordering if term.lstrip('-') in field_names]
+ def remove_invalid_fields(self, queryset, ordering, view):
+ valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
+
+ if valid_fields is None:
+ # Default to allowing filtering on serializer fields
+ serializer_class = getattr(view, 'serializer_class')
+ if serializer_class is None:
+ msg = ("Cannot use %s on a view which does not have either a "
+ "'serializer_class' or 'ordering_fields' attribute.")
+ raise ImproperlyConfigured(msg % self.__class__.__name__)
+ valid_fields = [
+ field.source or field_name
+ for field_name, field in serializer_class().fields.items()
+ if not getattr(field, 'write_only', False)
+ ]
+ elif valid_fields == '__all__':
+ # View explictly allows filtering on any model field
+ valid_fields = [field.name for field in queryset.model._meta.fields]
+ valid_fields += queryset.query.aggregates.keys()
+
+ return [term for term in ordering if term.lstrip('-') in valid_fields]
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request)
if ordering:
# Skip any incorrect parameters
- ordering = self.remove_invalid_fields(queryset, ordering)
+ ordering = self.remove_invalid_fields(queryset, ordering, view)
if not ordering:
# Use 'ordering' attribute by default
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 6d204cf5..fd411ad3 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -175,6 +175,14 @@ class GenericAPIView(views.APIView):
method if you want to apply the configured filtering backend to the
default queryset.
"""
+ for backend in self.get_filter_backends():
+ queryset = backend().filter_queryset(self.request, queryset, self)
+ return queryset
+
+ def get_filter_backends(self):
+ """
+ Returns the list of filter backends that this view requires.
+ """
filter_backends = self.filter_backends or []
if not filter_backends and self.filter_backend:
warnings.warn(
@@ -185,10 +193,8 @@ class GenericAPIView(views.APIView):
PendingDeprecationWarning, stacklevel=2
)
filter_backends = [self.filter_backend]
+ return filter_backends
- for backend in filter_backends:
- queryset = backend().filter_queryset(self.request, queryset, self)
- return queryset
########################
### The following methods provide default implementations
@@ -338,6 +344,18 @@ class GenericAPIView(views.APIView):
"""
pass
+ def pre_delete(self, obj):
+ """
+ Placeholder method for calling before deleting an object.
+ """
+ pass
+
+ def post_delete(self, obj):
+ """
+ Placeholder method for calling after saving an object.
+ """
+ pass
+
def metadata(self, request):
"""
Return a dictionary of metadata about the view.
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 4606c78b..5fbcf700 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -6,10 +6,12 @@ which allows mixin classes to be composed in interesting ways.
"""
from __future__ import unicode_literals
+from django.core.exceptions import ValidationError
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import clone_request
+from rest_framework.settings import api_settings
import warnings
@@ -59,7 +61,7 @@ class CreateModelMixin(object):
def get_success_headers(self, data):
try:
- return {'Location': data['url']}
+ return {'Location': data[api_settings.URL_FIELD_NAME]}
except (TypeError, KeyError):
return {}
@@ -127,7 +129,12 @@ class UpdateModelMixin(object):
files=request.FILES, partial=partial)
if serializer.is_valid():
- self.pre_save(serializer.object)
+ try:
+ self.pre_save(serializer.object)
+ except ValidationError as err:
+ # full_clean on model instance may be called in pre_save, so we
+ # have to handle eventual errors.
+ return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)
self.object = serializer.save(**save_kwargs)
self.post_save(self.object, created=created)
return Response(serializer.data, status=success_status_code)
@@ -186,5 +193,7 @@ class DestroyModelMixin(object):
"""
def destroy(self, request, *args, **kwargs):
obj = self.get_object()
+ self.pre_delete(obj)
obj.delete()
+ self.post_delete(obj)
return Response(status=status.HTTP_204_NO_CONTENT)
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 98fc0341..f1b3e38d 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -83,7 +83,7 @@ class YAMLParser(BaseParser):
data = stream.read().decode(encoding)
return yaml.safe_load(data)
except (ValueError, yaml.parser.ParserError) as exc:
- raise ParseError('YAML parse error - %s' % six.u(exc))
+ raise ParseError('YAML parse error - %s' % six.text_type(exc))
class FormParser(BaseParser):
@@ -153,7 +153,7 @@ class XMLParser(BaseParser):
try:
tree = etree.parse(stream, parser=parser, forbid_dtd=True)
except (etree.ParseError, ValueError) as exc:
- raise ParseError('XML parse error - %s' % six.u(exc))
+ raise ParseError('XML parse error - %s' % six.text_type(exc))
data = self._xml_convert(tree.getroot())
return data
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index ab6655e7..f24a5123 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -54,9 +54,7 @@ class IsAuthenticated(BasePermission):
"""
def has_permission(self, request, view):
- if request.user and request.user.is_authenticated():
- return True
- return False
+ return request.user and request.user.is_authenticated()
class IsAdminUser(BasePermission):
@@ -65,9 +63,7 @@ class IsAdminUser(BasePermission):
"""
def has_permission(self, request, view):
- if request.user and request.user.is_staff:
- return True
- return False
+ return request.user and request.user.is_staff
class IsAuthenticatedOrReadOnly(BasePermission):
@@ -76,11 +72,9 @@ class IsAuthenticatedOrReadOnly(BasePermission):
"""
def has_permission(self, request, view):
- if (request.method in SAFE_METHODS or
- request.user and
- request.user.is_authenticated()):
- return True
- return False
+ return (request.method in SAFE_METHODS or
+ request.user and
+ request.user.is_authenticated())
class DjangoModelPermissions(BasePermission):
@@ -138,11 +132,9 @@ class DjangoModelPermissions(BasePermission):
perms = self.get_required_permissions(request.method, model_cls)
- if (request.user and
+ return (request.user and
(request.user.is_authenticated() or not self.authenticated_users_only) and
- request.user.has_perms(perms)):
- return True
- return False
+ request.user.has_perms(perms))
class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 35c00bf1..02185c2f 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -65,16 +65,11 @@ class RelatedField(WritableField):
def initialize(self, parent, field_name):
super(RelatedField, self).initialize(parent, field_name)
if self.queryset is None and not self.read_only:
- try:
- manager = getattr(self.parent.opts.model, self.source or field_name)
- if hasattr(manager, 'related'): # Forward
- self.queryset = manager.related.model._default_manager.all()
- else: # Reverse
- self.queryset = manager.field.rel.to._default_manager.all()
- except Exception:
- msg = ('Serializer related fields must include a `queryset`' +
- ' argument or set `read_only=True')
- raise Exception(msg)
+ manager = getattr(self.parent.opts.model, self.source or field_name)
+ if hasattr(manager, 'related'): # Forward
+ self.queryset = manager.related.model._default_manager.all()
+ else: # Reverse
+ self.queryset = manager.field.rel.to._default_manager.all()
### We need this stuff to make form choices work...
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index fe4f43d4..2fdd3337 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -20,6 +20,7 @@ from rest_framework.compat import StringIO
from rest_framework.compat import six
from rest_framework.compat import smart_text
from rest_framework.compat import yaml
+from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings
from rest_framework.request import is_form_media_type, override_method
from rest_framework.utils import encoders
@@ -420,8 +421,12 @@ class BrowsableAPIRenderer(BaseRenderer):
In the absence of the View having an associated form then return None.
"""
if request.method == method:
- data = request.DATA
- files = request.FILES
+ try:
+ data = request.DATA
+ files = request.FILES
+ except ParseError:
+ data = None
+ files = None
else:
data = None
files = None
diff --git a/rest_framework/request.py b/rest_framework/request.py
index a570dcae..ca70b49e 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -223,7 +223,7 @@ class Request(object):
def user(self, value):
"""
Sets the user on the current request. This is necessary to maintain
- compatilbility with django.contrib.auth where the user proprety is
+ compatibility with django.contrib.auth where the user property is
set in the login and logout functions.
"""
self._user = value
@@ -333,7 +333,7 @@ class Request(object):
self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data):
self._content_type = self._data[self._CONTENTTYPE_PARAM]
- self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(HTTP_HEADER_ENCODING))
+ self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty)
def _parse(self):
@@ -355,7 +355,16 @@ class Request(object):
if not parser:
raise exceptions.UnsupportedMediaType(media_type)
- parsed = parser.parse(stream, media_type, self.parser_context)
+ try:
+ parsed = parser.parse(stream, media_type, self.parser_context)
+ except:
+ # If we get an exception during parsing, fill in empty data and
+ # re-raise. Ensures we don't simply repeat the error when
+ # attempting to render the browsable renderer response, or when
+ # logging the request or similar.
+ self._data = QueryDict('', self._request._encoding)
+ self._files = MultiValueDict()
+ raise
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 5877c8a3..1dc6abcf 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -61,6 +61,10 @@ class Response(SimpleTemplateResponse):
assert charset, 'renderer returned unicode, and did not specify ' \
'a charset value.'
return bytes(ret.encode(charset))
+
+ if not ret:
+ del self['Content-Type']
+
return ret
@property
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 3fee1e49..97b35c10 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -184,18 +184,24 @@ class SimpleRouter(BaseRouter):
bound_methods[method] = action
return bound_methods
- def get_lookup_regex(self, viewset):
+ def get_lookup_regex(self, viewset, lookup_prefix=''):
"""
Given a viewset, return the portion of URL regex that is used
to match against a single instance.
+
+ Note that lookup_prefix is not used directly inside REST rest_framework
+ itself, but is required in order to nicely support nested router
+ implementations, such as drf-nested-routers.
+
+ https://github.com/alanjds/drf-nested-routers
"""
if self.trailing_slash:
- base_regex = '(?P<{lookup_field}>[^/]+)'
+ base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/]+)'
else:
# Don't consume `.json` style suffixes
- base_regex = '(?P<{lookup_field}>[^/.]+)'
+ base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/.]+)'
lookup_field = getattr(viewset, 'lookup_field', 'pk')
- return base_regex.format(lookup_field=lookup_field)
+ return base_regex.format(lookup_field=lookup_field, lookup_prefix=lookup_prefix)
def get_urls(self):
"""
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index be721658..3fc0eb2f 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -100,6 +100,9 @@ INSTALLED_APPS = (
'rest_framework',
'rest_framework.authtoken',
'rest_framework.tests',
+ 'rest_framework.tests.accounts',
+ 'rest_framework.tests.records',
+ 'rest_framework.tests.users',
)
# OAuth is optional and won't work if there is no oauth_provider & oauth2
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 4210d058..536b040b 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -6,13 +6,14 @@ form encoded input.
Serialization in REST framework is a two-phase process:
1. Serializers marshal between complex types like model instances, and
-python primatives.
-2. The process of marshalling between python primatives and request and
+python primitives.
+2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
from __future__ import unicode_literals
import copy
import datetime
+import inspect
import types
from decimal import Decimal
from django.core.paginator import Page
@@ -20,6 +21,8 @@ from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model, six
+from rest_framework.settings import api_settings
+
# Note: We do the following so that users of the framework can use this style:
#
@@ -32,6 +35,27 @@ from rest_framework.relations import *
from rest_framework.fields import *
+def _resolve_model(obj):
+ """
+ Resolve supplied `obj` to a Django model class.
+
+ `obj` must be a Django model class itself, or a string
+ representation of one. Useful in situtations like GH #1225 where
+ Django may not have resolved a string-based reference to a model in
+ another model's foreign key definition.
+
+ String representations should have the format:
+ 'appname.ModelName'
+ """
+ if type(obj) == str and len(obj.split('.')) == 2:
+ app_name, model_name = obj.split('.')
+ return models.get_model(app_name, model_name)
+ elif inspect.isclass(obj) and issubclass(obj, models.Model):
+ return obj
+ else:
+ raise ValueError("{0} is not a Django model".format(obj))
+
+
def pretty_name(name):
"""Converts 'first_name' to 'First name'"""
if not name:
@@ -42,6 +66,7 @@ def pretty_name(name):
class RelationsList(list):
_deleted = []
+
class NestedValidationError(ValidationError):
"""
The default ValidationError behavior is to stringify each item in the list
@@ -56,9 +81,13 @@ class NestedValidationError(ValidationError):
def __init__(self, message):
if isinstance(message, dict):
- self.messages = [message]
+ self._messages = [message]
else:
- self.messages = message
+ self._messages = message
+
+ @property
+ def messages(self):
+ return self._messages
class DictWithMetadata(dict):
@@ -321,12 +350,13 @@ class BaseSerializer(WritableField):
method = getattr(self, 'transform_%s' % field_name, None)
if callable(method):
value = method(obj, value)
- ret[key] = value
+ if not getattr(field, 'write_only', False):
+ ret[key] = value
ret.fields[key] = self.augment_field(field, field_name, key, value)
return ret
- def from_native(self, data, files):
+ def from_native(self, data, files=None):
"""
Deserialize primitives -> objects.
"""
@@ -356,6 +386,9 @@ class BaseSerializer(WritableField):
Override default so that the serializer can be used as a nested field
across relationships.
"""
+ if self.write_only:
+ return None
+
if self.source == '*':
return self.to_native(obj)
@@ -407,11 +440,19 @@ class BaseSerializer(WritableField):
# Set the serializer object if it exists
obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
- obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
+
+ # If we have a model manager or similar object then we need
+ # to iterate through each instance.
+ if (self.many and
+ not hasattr(obj, '__iter__') and
+ is_simple_callable(getattr(obj, 'all', None))):
+ obj = obj.all()
if self.source == '*':
if value:
- into.update(value)
+ reverted_data = self.restore_fields(value, {})
+ if not self._errors:
+ into.update(reverted_data)
else:
if value in (None, ''):
into[(self.source or field_name)] = None
@@ -580,6 +621,7 @@ class ModelSerializerOptions(SerializerOptions):
super(ModelSerializerOptions, self).__init__(meta)
self.model = getattr(meta, 'model', None)
self.read_only_fields = getattr(meta, 'read_only_fields', ())
+ self.write_only_fields = getattr(meta, 'write_only_fields', ())
class ModelSerializer(Serializer):
@@ -606,6 +648,7 @@ class ModelSerializer(Serializer):
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
+ models.NullBooleanField: BooleanField,
models.FileField: FileField,
models.ImageField: ImageField,
}
@@ -642,7 +685,7 @@ class ModelSerializer(Serializer):
if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
- related_model = model_field.rel.to
+ related_model = _resolve_model(model_field.rel.to)
if to_many and not model_field.rel.through._meta.auto_created:
has_through_model = True
@@ -699,7 +742,9 @@ class ModelSerializer(Serializer):
is_m2m = isinstance(relation.field,
models.fields.related.ManyToManyField)
- if is_m2m and not relation.field.rel.through._meta.auto_created:
+ if (is_m2m and
+ hasattr(relation.field.rel, 'through') and
+ not relation.field.rel.through._meta.auto_created):
has_through_model = True
if nested:
@@ -716,17 +761,29 @@ class ModelSerializer(Serializer):
# Add the `read_only` flag to any fields that have bee specified
# in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
- assert field_name not in self.base_fields.keys(), \
- "field '%s' on serializer '%s' specified in " \
- "`read_only_fields`, but also added " \
- "as an explicit field. Remove it from `read_only_fields`." % \
- (field_name, self.__class__.__name__)
- assert field_name in ret, \
- "Non-existant field '%s' specified in `read_only_fields` " \
- "on serializer '%s'." % \
- (field_name, self.__class__.__name__)
+ assert field_name not in self.base_fields.keys(), (
+ "field '%s' on serializer '%s' specified in "
+ "`read_only_fields`, but also added "
+ "as an explicit field. Remove it from `read_only_fields`." %
+ (field_name, self.__class__.__name__))
+ assert field_name in ret, (
+ "Non-existant field '%s' specified in `read_only_fields` "
+ "on serializer '%s'." %
+ (field_name, self.__class__.__name__))
ret[field_name].read_only = True
+ for field_name in self.opts.write_only_fields:
+ assert field_name not in self.base_fields.keys(), (
+ "field '%s' on serializer '%s' specified in "
+ "`write_only_fields`, but also added "
+ "as an explicit field. Remove it from `write_only_fields`." %
+ (field_name, self.__class__.__name__))
+ assert field_name in ret, (
+ "Non-existant field '%s' specified in `write_only_fields` "
+ "on serializer '%s'." %
+ (field_name, self.__class__.__name__))
+ ret[field_name].write_only = True
+
return ret
def get_pk_field(self, model_field):
@@ -794,6 +851,8 @@ class ModelSerializer(Serializer):
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
kwargs['choices'] = model_field.flatchoices
+ if model_field.null:
+ kwargs['empty'] = None
return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
@@ -865,18 +924,18 @@ class ModelSerializer(Serializer):
# Reverse fk or one-to-one relations
for (obj, model) in meta.get_all_related_objects_with_model():
- field_name = obj.field.related_query_name()
+ field_name = obj.get_accessor_name()
if field_name in attrs:
related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations
for (obj, model) in meta.get_all_related_m2m_objects_with_model():
- field_name = obj.field.related_query_name()
+ field_name = obj.get_accessor_name()
if field_name in attrs:
m2m_data[field_name] = attrs.pop(field_name)
# Forward m2m relations
- for field in meta.many_to_many:
+ for field in meta.many_to_many + meta.virtual_fields:
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
@@ -889,7 +948,10 @@ class ModelSerializer(Serializer):
# Update an existing instance...
if instance is not None:
for key, val in attrs.items():
- setattr(instance, key, val)
+ try:
+ setattr(instance, key, val)
+ except ValueError:
+ self._errors[key] = self.error_messages['required']
# ...or create a new instance
else:
@@ -933,11 +995,16 @@ class ModelSerializer(Serializer):
del(obj._m2m_data)
if getattr(obj, '_related_data', None):
+ related_fields = dict([
+ (field.get_accessor_name(), field)
+ for field, model
+ in obj._meta.get_all_related_objects_with_model()
+ ])
for accessor_name, related in obj._related_data.items():
if isinstance(related, RelationsList):
# Nested reverse fk relationship
for related_item in related:
- fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ fk_field = related_fields[accessor_name].field.name
setattr(related_item, fk_field, obj)
self.save_object(related_item)
@@ -964,6 +1031,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
self.lookup_field = getattr(meta, 'lookup_field', None)
+ self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)
class HyperlinkedModelSerializer(ModelSerializer):
@@ -982,13 +1050,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
- if 'url' not in fields:
+ if self.opts.url_field_name not in fields:
url_field = self._hyperlink_identify_field_class(
view_name=self.opts.view_name,
lookup_field=self.opts.lookup_field
)
ret = self._dict_class()
- ret['url'] = url_field
+ ret[self.opts.url_field_name] = url_field
ret.update(fields)
fields = ret
@@ -1024,7 +1092,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
We need to override the default, to use the url as the identity.
"""
try:
- return data.get('url', None)
+ return data.get(self.opts.url_field_name, None)
except AttributeError:
return None
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 8abaf140..ce171d6d 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -95,6 +95,7 @@ DEFAULTS = {
'URL_FORMAT_OVERRIDE': 'format',
'FORMAT_SUFFIX_KWARG': 'format',
+ 'URL_FIELD_NAME': 'url',
# Input and output formats
'DATE_INPUT_FORMATS': (
diff --git a/rest_framework/status.py b/rest_framework/status.py
index b9f249f9..76435371 100644
--- a/rest_framework/status.py
+++ b/rest_framework/status.py
@@ -6,6 +6,23 @@ And RFC 6585 - http://tools.ietf.org/html/rfc6585
"""
from __future__ import unicode_literals
+
+def is_informational(code):
+ return code >= 100 and code <= 199
+
+def is_success(code):
+ return code >= 200 and code <= 299
+
+def is_redirect(code):
+ return code >= 300 and code <= 399
+
+def is_client_error(code):
+ return code >= 400 and code <= 499
+
+def is_server_error(code):
+ return code >= 500 and code <= 599
+
+
HTTP_100_CONTINUE = 100
HTTP_101_SWITCHING_PROTOCOLS = 101
HTTP_200_OK = 200
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 7ab17dff..d19d5a2b 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -33,7 +33,7 @@
<div class="navbar-inner">
<div class="container-fluid">
<span href="/">
- {% block branding %}<a class='brand' href='http://django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %}
+ {% block branding %}<a class='brand' rel="nofollow" href='http://www.django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %}
</span>
<ul class="nav pull-right">
{% block userlinks %}
@@ -221,9 +221,6 @@
</div><!-- ./wrapper -->
{% block footer %}
- <!--<div id="footer">
- <a class="powered-by" href='http://django-rest-framework.org'>Django REST framework</a>
- </div>-->
{% endblock %}
{% block script %}
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index e9c1cdd5..83c046f9 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -2,6 +2,7 @@ from __future__ import unicode_literals, absolute_import
from django import template
from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import QueryDict
+from django.utils.encoding import iri_to_uri
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
from rest_framework.compat import urlparse, force_text, six, smart_urlquote
@@ -144,7 +145,9 @@ def add_query_param(request, key, val):
"""
Add a query parameter to the current request url, and return the new url.
"""
- return replace_query_param(request.get_full_path(), key, val)
+ iri = request.get_full_path()
+ uri = iri_to_uri(iri)
+ return replace_query_param(uri, key, val)
@register.filter
diff --git a/rest_framework/tests/accounts/__init__.py b/rest_framework/tests/accounts/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/rest_framework/tests/accounts/__init__.py
diff --git a/rest_framework/tests/accounts/models.py b/rest_framework/tests/accounts/models.py
new file mode 100644
index 00000000..525e601b
--- /dev/null
+++ b/rest_framework/tests/accounts/models.py
@@ -0,0 +1,8 @@
+from django.db import models
+
+from rest_framework.tests.users.models import User
+
+
+class Account(models.Model):
+ owner = models.ForeignKey(User, related_name='accounts_owned')
+ admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered')
diff --git a/rest_framework/tests/accounts/serializers.py b/rest_framework/tests/accounts/serializers.py
new file mode 100644
index 00000000..a27b9ca6
--- /dev/null
+++ b/rest_framework/tests/accounts/serializers.py
@@ -0,0 +1,11 @@
+from rest_framework import serializers
+
+from rest_framework.tests.accounts.models import Account
+from rest_framework.tests.users.serializers import UserSerializer
+
+
+class AccountSerializer(serializers.ModelSerializer):
+ admins = UserSerializer(many=True)
+
+ class Meta:
+ model = Account
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 1598ecd9..32a726c0 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -70,6 +70,7 @@ class Comment(RESTFrameworkModel):
class ActionItem(RESTFrameworkModel):
title = models.CharField(max_length=200)
+ started = models.NullBooleanField(default=False)
done = models.BooleanField(default=False)
info = CustomField(default='---', max_length=12)
diff --git a/rest_framework/tests/records/__init__.py b/rest_framework/tests/records/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/rest_framework/tests/records/__init__.py
diff --git a/rest_framework/tests/records/models.py b/rest_framework/tests/records/models.py
new file mode 100644
index 00000000..76954807
--- /dev/null
+++ b/rest_framework/tests/records/models.py
@@ -0,0 +1,6 @@
+from django.db import models
+
+
+class Record(models.Model):
+ account = models.ForeignKey('accounts.Account', blank=True, null=True)
+ owner = models.ForeignKey('users.User', blank=True, null=True)
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
index a44813b6..f072b81b 100644
--- a/rest_framework/tests/test_authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -249,7 +249,7 @@ class OAuthTests(TestCase):
def setUp(self):
# these imports are here because oauth is optional and hiding them in try..except block or compat
# could obscure problems if something breaks
- from oauth_provider.models import Consumer, Resource
+ from oauth_provider.models import Consumer, Scope
from oauth_provider.models import Token as OAuthToken
from oauth_provider import consts
@@ -269,8 +269,8 @@ class OAuthTests(TestCase):
self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
name='example', user=self.user, status=self.consts.ACCEPTED)
- self.resource = Resource.objects.create(name="resource name", url="api/")
- self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, resource=self.resource,
+ self.scope = Scope.objects.create(name="resource name", url="api/")
+ self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope,
token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
)
@@ -362,7 +362,8 @@ class OAuthTests(TestCase):
def test_post_form_with_urlencoded_parameters(self):
"""Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
params = self._create_authorization_url_parameters()
- response = self.csrf_client.post('/oauth/', params)
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
@@ -397,10 +398,10 @@ class OAuthTests(TestCase):
@unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
@unittest.skipUnless(oauth, 'oauth2 not installed')
def test_get_form_with_readonly_resource_passing_auth(self):
- """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
read_only_access_token = self.token
- read_only_access_token.resource.is_readonly = True
- read_only_access_token.resource.save()
+ read_only_access_token.scope.is_readonly = True
+ read_only_access_token.scope.save()
params = self._create_authorization_url_parameters()
response = self.csrf_client.get('/oauth-with-scope/', params)
self.assertEqual(response.status_code, 200)
@@ -410,8 +411,8 @@ class OAuthTests(TestCase):
def test_post_form_with_readonly_resource_failing_auth(self):
"""Ensure POSTing with a readonly resource instead of a write scope fails"""
read_only_access_token = self.token
- read_only_access_token.resource.is_readonly = True
- read_only_access_token.resource.save()
+ read_only_access_token.scope.is_readonly = True
+ read_only_access_token.scope.save()
params = self._create_authorization_url_parameters()
response = self.csrf_client.post('/oauth-with-scope/', params)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@@ -421,10 +422,11 @@ class OAuthTests(TestCase):
def test_post_form_with_write_resource_passing_auth(self):
"""Ensure POSTing with a write resource succeed"""
read_write_access_token = self.token
- read_write_access_token.resource.is_readonly = False
- read_write_access_token.resource.save()
+ read_write_access_token.scope.is_readonly = False
+ read_write_access_token.scope.save()
params = self._create_authorization_url_parameters()
- response = self.csrf_client.post('/oauth-with-scope/', params)
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index 34fbab9c..5c96bce9 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -42,6 +42,31 @@ class TimeFieldModelSerializer(serializers.ModelSerializer):
model = TimeFieldModel
+SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+]
+
+
+class ChoiceFieldModel(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255)
+
+
+class ChoiceFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModel
+
+
+class ChoiceFieldModelWithNull(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255)
+
+
+class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModelWithNull
+
+
class BasicFieldTests(TestCase):
def test_auto_now_fields_read_only(self):
"""
@@ -667,34 +692,71 @@ class ChoiceFieldTests(TestCase):
"""
Tests for the ChoiceField options generator
"""
-
- SAMPLE_CHOICES = [
- ('red', 'Red'),
- ('green', 'Green'),
- ('blue', 'Blue'),
- ]
-
def test_choices_required(self):
"""
Make sure proper choices are rendered if field is required
"""
- f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES)
- self.assertEqual(f.choices, self.SAMPLE_CHOICES)
+ f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, SAMPLE_CHOICES)
def test_choices_not_required(self):
"""
Make sure proper choices (plus blank) are rendered if the field isn't required
"""
- f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)
- self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES)
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
+
+ def test_invalid_choice_model(self):
+ s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
+ self.assertEqual(s.data['choice'], '')
+
+ def test_empty_choice_model(self):
+ """
+ Test that the 'empty' value is correctly passed and used depending on
+ the 'null' property on the model field.
+ """
+ s = ChoiceFieldModelSerializer(data={'choice': ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], '')
+
+ s = ChoiceFieldModelWithNullSerializer(data={'choice': ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], None)
def test_from_native_empty(self):
"""
- Make sure from_native() returns None on empty param.
+ Make sure from_native() returns an empty string on empty param by default.
"""
- f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES)
- result = f.from_native('')
- self.assertEqual(result, None)
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.from_native(''), '')
+ self.assertEqual(f.from_native(None), '')
+
+ def test_from_native_empty_override(self):
+ """
+ Make sure you can override from_native() behavior regarding empty values.
+ """
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None)
+ self.assertEqual(f.from_native(''), None)
+ self.assertEqual(f.from_native(None), None)
+
+ def test_metadata_choices(self):
+ """
+ Make sure proper choices are included in the field's metadata.
+ """
+ choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES]
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.metadata()['choices'], choices)
+
+ def test_metadata_choices_not_required(self):
+ """
+ Make sure proper choices are included in the field's metadata.
+ """
+ choices = [{'value': v, 'display_name': n}
+ for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES]
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.metadata()['choices'], choices)
class EmailFieldTests(TestCase):
diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py
index 379db29d..18188186 100644
--- a/rest_framework/tests/test_filters.py
+++ b/rest_framework/tests/test_filters.py
@@ -363,6 +363,11 @@ class OrdringFilterModel(models.Model):
text = models.CharField(max_length=100)
+class OrderingFilterRelatedModel(models.Model):
+ related_object = models.ForeignKey(OrdringFilterModel,
+ related_name="relateds")
+
+
class OrderingFilterTests(TestCase):
def setUp(self):
# Sequence of title/text is:
@@ -388,6 +393,7 @@ class OrderingFilterTests(TestCase):
model = OrdringFilterModel
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('?ordering=text')
@@ -406,6 +412,7 @@ class OrderingFilterTests(TestCase):
model = OrdringFilterModel
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('?ordering=-text')
@@ -424,6 +431,7 @@ class OrderingFilterTests(TestCase):
model = OrdringFilterModel
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('?ordering=foobar')
@@ -442,6 +450,7 @@ class OrderingFilterTests(TestCase):
model = OrdringFilterModel
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
+ oredering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('')
@@ -460,6 +469,7 @@ class OrderingFilterTests(TestCase):
model = OrdringFilterModel
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
+ ordering_fields = ('text',)
view = OrderingListView.as_view()
request = factory.get('')
@@ -472,3 +482,134 @@ class OrderingFilterTests(TestCase):
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
)
+
+ def test_ordering_by_aggregate_field(self):
+ # create some related models to aggregate order by
+ num_objs = [2, 5, 3]
+ for obj, num_relateds in zip(OrdringFilterModel.objects.all(),
+ num_objs):
+ for _ in range(num_relateds):
+ new_related = OrderingFilterRelatedModel(
+ related_object=obj
+ )
+ new_related.save()
+
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = '__all__'
+ queryset = OrdringFilterModel.objects.all().annotate(
+ models.Count("relateds"))
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=relateds__count')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ ]
+ )
+
+
+class SensitiveOrderingFilterModel(models.Model):
+ username = models.CharField(max_length=20)
+ password = models.CharField(max_length=100)
+
+
+# Three different styles of serializer.
+# All should allow ordering by username, but not by password.
+class SensitiveDataSerializer1(serializers.ModelSerializer):
+ username = serializers.CharField()
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'username')
+
+
+class SensitiveDataSerializer2(serializers.ModelSerializer):
+ username = serializers.CharField()
+ password = serializers.CharField(write_only=True)
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'username', 'password')
+
+
+class SensitiveDataSerializer3(serializers.ModelSerializer):
+ user = serializers.CharField(source='username')
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'user')
+
+
+class SensitiveOrderingFilterTests(TestCase):
+ def setUp(self):
+ for idx in range(3):
+ username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
+ password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
+ SensitiveOrderingFilterModel(username=username, password=password).save()
+
+ def test_order_by_serializer_fields(self):
+ for serializer_cls in [
+ SensitiveDataSerializer1,
+ SensitiveDataSerializer2,
+ SensitiveDataSerializer3
+ ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-username')
+ response = view(request)
+
+ if serializer_cls == SensitiveDataSerializer3:
+ username_field = 'user'
+ else:
+ username_field = 'username'
+
+ # Note: Inverse username ordering correctly applied.
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, username_field: 'userC'},
+ {'id': 2, username_field: 'userB'},
+ {'id': 1, username_field: 'userA'},
+ ]
+ )
+
+ def test_cannot_order_by_non_serializer_fields(self):
+ for serializer_cls in [
+ SensitiveDataSerializer1,
+ SensitiveDataSerializer2,
+ SensitiveDataSerializer3
+ ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=password')
+ response = view(request)
+
+ if serializer_cls == SensitiveDataSerializer3:
+ username_field = 'user'
+ else:
+ username_field = 'username'
+
+ # Note: The passwords are not in order. Default ordering is used.
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, username_field: 'userA'}, # PassB
+ {'id': 2, username_field: 'userB'}, # PassC
+ {'id': 3, username_field: 'userC'}, # PassA
+ ]
+ ) \ No newline at end of file
diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py
index c38bfb9f..2d341344 100644
--- a/rest_framework/tests/test_genericrelations.py
+++ b/rest_framework/tests/test_genericrelations.py
@@ -69,6 +69,35 @@ class TestGenericRelations(TestCase):
}
self.assertEqual(serializer.data, expected)
+ def test_generic_nested_relation(self):
+ """
+ Test saving a GenericRelation field via a nested serializer.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Tag
+ exclude = ('content_type', 'object_id')
+
+ class BookmarkSerializer(serializers.ModelSerializer):
+ tags = TagSerializer()
+
+ class Meta:
+ model = Bookmark
+ exclude = ('id',)
+
+ data = {
+ 'url': 'https://docs.djangoproject.com/',
+ 'tags': [
+ {'tag': 'contenttypes'},
+ {'tag': 'genericrelations'},
+ ]
+ }
+ serializer = BookmarkSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.object.tags.count(), 2)
+
def test_generic_fk(self):
"""
Test a relationship that spans a GenericForeignKey field.
diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py
index 79cd99ac..996bd5b0 100644
--- a/rest_framework/tests/test_generics.py
+++ b/rest_framework/tests/test_generics.py
@@ -23,6 +23,10 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView):
"""
model = BasicModel
+ def get_queryset(self):
+ queryset = super(InstanceView, self).get_queryset()
+ return queryset.exclude(text='filtered out')
+
class SlugSerializer(serializers.ModelSerializer):
slug = serializers.Field() # read only
@@ -160,10 +164,10 @@ class TestInstanceView(TestCase):
"""
Create 3 BasicModel intances.
"""
- items = ['foo', 'bar', 'baz']
+ items = ['foo', 'bar', 'baz', 'filtered out']
for item in items:
BasicModel(text=item).save()
- self.objects = BasicModel.objects
+ self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
@@ -352,6 +356,17 @@ class TestInstanceView(TestCase):
updated = self.objects.get(id=1)
self.assertEqual(updated.text, 'foobar')
+ def test_put_to_filtered_out_instance(self):
+ """
+ PUT requests to an URL of instance which is filtered out should not be
+ able to create new objects.
+ """
+ data = {'text': 'foo'}
+ filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
+ request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
+ response = self.view(request, pk=filtered_out_pk).render()
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
def test_put_as_create_on_id_based_url(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView should create an object
@@ -508,6 +523,25 @@ class ExclusiveFilterBackend(object):
return queryset.filter(text='other')
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ model = TwoFieldModel
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ return DynamicSerializer
+ return super(DynamicSerializerView, self).get_serializer_class()
+
+
class TestFilterBackendAppliedToViews(TestCase):
def setUp(self):
@@ -564,28 +598,6 @@ class TestFilterBackendAppliedToViews(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
-
-class TwoFieldModel(models.Model):
- field_a = models.CharField(max_length=100)
- field_b = models.CharField(max_length=100)
-
-
-class DynamicSerializerView(generics.ListCreateAPIView):
- model = TwoFieldModel
- renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
-
- def get_serializer_class(self):
- if self.request.method == 'POST':
- class DynamicSerializer(serializers.ModelSerializer):
- class Meta:
- model = TwoFieldModel
- fields = ('field_b',)
- return DynamicSerializer
- return super(DynamicSerializerView, self).get_serializer_class()
-
-
-class TestFilterBackendAppliedToViews(TestCase):
-
def test_dynamic_serializer_form_in_browsable_api(self):
"""
GET requests to ListCreateAPIView should return filtered list.
diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py
index 61e613d7..83d46043 100644
--- a/rest_framework/tests/test_hyperlinkedserializers.py
+++ b/rest_framework/tests/test_hyperlinkedserializers.py
@@ -3,6 +3,7 @@ import json
from django.test import TestCase
from rest_framework import generics, status, serializers
from rest_framework.compat import patterns, url
+from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import (
Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
@@ -331,3 +332,48 @@ class TestOverriddenURLField(TestCase):
serializer.data,
{'title': 'New blog post', 'url': 'foo bar'}
)
+
+
+class TestURLFieldNameBySettings(TestCase):
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ self.saved_url_field_name = api_settings.URL_FIELD_NAME
+ api_settings.URL_FIELD_NAME = 'global_url_field'
+
+ class Serializer(serializers.HyperlinkedModelSerializer):
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', api_settings.URL_FIELD_NAME)
+
+ self.Serializer = Serializer
+ self.obj = BlogPost.objects.create(title="New blog post")
+
+ def tearDown(self):
+ api_settings.URL_FIELD_NAME = self.saved_url_field_name
+
+ def test_overridden_url_field_name(self):
+ request = factory.get('/posts/')
+ serializer = self.Serializer(self.obj, context={'request': request})
+ self.assertIn(api_settings.URL_FIELD_NAME, serializer.data)
+
+
+class TestURLFieldNameByOptions(TestCase):
+ urls = 'rest_framework.tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ class Serializer(serializers.HyperlinkedModelSerializer):
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', 'serializer_url_field')
+ url_field_name = 'serializer_url_field'
+
+ self.Serializer = Serializer
+ self.obj = BlogPost.objects.create(title="New blog post")
+
+ def test_overridden_url_field_name(self):
+ request = factory.get('/posts/')
+ serializer = self.Serializer(self.obj, context={'request': request})
+ self.assertIn(self.Serializer.Meta.url_field_name, serializer.data)
diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py
index d19219c9..f52e0e1e 100644
--- a/rest_framework/tests/test_relations.py
+++ b/rest_framework/tests/test_relations.py
@@ -98,3 +98,23 @@ class RelatedFieldSourceTests(TestCase):
obj = ClassWithQuerysetMethod()
value = field.field_to_native(obj, 'field_name')
self.assertEqual(value, ['BlogPost object'])
+
+ # Regression for #1129
+ def test_exception_for_incorect_fk(self):
+ """
+ Check that the exception message are correct if the source field
+ doesn't exist.
+ """
+ from rest_framework.tests.models import ManyToManySource
+ class Meta:
+ model = ManyToManySource
+ attrs = {
+ 'name': serializers.SlugRelatedField(
+ slug_field='name', source='banzai'),
+ 'Meta': Meta,
+ }
+
+ TestSerializer = type(str('TestSerializer'),
+ (serializers.ModelSerializer,), attrs)
+ with self.assertRaises(AttributeError):
+ TestSerializer(data={'name': 'foo'})
diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py
index 76299a89..fb33df2c 100644
--- a/rest_framework/tests/test_renderers.py
+++ b/rest_framework/tests/test_renderers.py
@@ -3,6 +3,7 @@ from __future__ import unicode_literals
from decimal import Decimal
from django.core.cache import cache
+from django.db import models
from django.test import TestCase
from django.utils import unittest
from django.utils.translation import ugettext_lazy as _
@@ -15,7 +16,9 @@ from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
+from collections import MutableMapping
import datetime
+import json
import pickle
import re
@@ -32,6 +35,10 @@ expected_results = [
]
+class DummyTestModel(models.Model):
+ name = models.CharField(max_length=42, default='')
+
+
class BasicRendererTests(TestCase):
def test_expected_results(self):
for value, renderer_cls, expected in expected_results:
@@ -64,11 +71,23 @@ class MockView(APIView):
class MockGETView(APIView):
-
def get(self, request, **kwargs):
return Response({'foo': ['bar', 'baz']})
+
+class MockPOSTView(APIView):
+ def post(self, request, **kwargs):
+ return Response({'foo': request.DATA})
+
+
+class EmptyGETView(APIView):
+ renderer_classes = (JSONRenderer,)
+
+ def get(self, request, **kwargs):
+ return Response(status=status.HTTP_204_NO_CONTENT)
+
+
class HTMLView(APIView):
renderer_classes = (BrowsableAPIRenderer, )
@@ -88,8 +107,10 @@ urlpatterns = patterns('',
url(r'^cache$', MockGETView.as_view()),
url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
+ url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
+ url(r'^empty$', EmptyGETView.as_view()),
url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
)
@@ -219,6 +240,22 @@ class RendererEndToEndTests(TestCase):
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
+ def test_parse_error_renderers_browsable_api(self):
+ """Invalid data should still render the browsable API correctly."""
+ resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
+
+ def test_204_no_content_responses_have_no_content_type_set(self):
+ """
+ Regression test for #1196
+
+ https://github.com/tomchristie/django-rest-framework/issues/1196
+ """
+ resp = self.client.get('/empty')
+ self.assertEqual(resp.get('Content-Type', None), None)
+ self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
+
_flat_repr = '{"foo": ["bar", "baz"]}'
_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
@@ -244,6 +281,58 @@ class JSONRendererTests(TestCase):
ret = JSONRenderer().render(_('test'))
self.assertEqual(ret, b'"test"')
+ def test_render_queryset_values(self):
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, [{'id': o.id, 'name': o.name}])
+
+ def test_render_queryset_values_list(self):
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values_list('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, [[o.id, o.name]])
+
+ def test_render_dict_abc_obj(self):
+ class Dict(MutableMapping):
+ def __init__(self):
+ self._dict = dict()
+ def __getitem__(self, key):
+ return self._dict.__getitem__(key)
+ def __setitem__(self, key, value):
+ return self._dict.__setitem__(key, value)
+ def __delitem__(self, key):
+ return self._dict.__delitem__(key)
+ def __iter__(self):
+ return self._dict.__iter__()
+ def __len__(self):
+ return self._dict.__len__()
+ def keys(self):
+ return self._dict.keys()
+
+ x = Dict()
+ x['key'] = 'string value'
+ x[2] = 3
+ ret = JSONRenderer().render(x)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, {'key': 'string value', '2': 3})
+
+ def test_render_obj_with_getitem(self):
+ class DictLike(object):
+ def __init__(self):
+ self._dict = {}
+ def set(self, value):
+ self._dict = dict(value)
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ x = DictLike()
+ x.set({'a': 1, 'b': 'string'})
+ with self.assertRaises(TypeError):
+ JSONRenderer().render(x)
+
def test_without_content_type_args(self):
"""
Test basic JSON rendering.
diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py
index 3929bcc2..c0b50f33 100644
--- a/rest_framework/tests/test_request.py
+++ b/rest_framework/tests/test_request.py
@@ -5,6 +5,7 @@ from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware
+from django.core.handlers.wsgi import WSGIRequest
from django.test import TestCase
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
@@ -15,12 +16,13 @@ from rest_framework.parsers import (
MultiPartParser,
JSONParser
)
-from rest_framework.request import Request
+from rest_framework.request import Request, Empty
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
from rest_framework.compat import six
+from io import BytesIO
import json
@@ -149,6 +151,34 @@ class TestContentParsing(TestCase):
request.parsers = (JSONParser(), )
self.assertEqual(request.DATA, json_data)
+ def test_form_POST_unicode(self):
+ """
+ JSON POST via default web interface with unicode data
+ """
+ # Note: environ and other variables here have simplified content compared to real Request
+ CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D'
+ environ = {
+ 'REQUEST_METHOD': 'POST',
+ 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
+ 'CONTENT_LENGTH': len(CONTENT),
+ 'wsgi.input': BytesIO(CONTENT),
+ }
+ wsgi_request = WSGIRequest(environ=environ)
+ wsgi_request._load_post_and_files()
+ parsers = (JSONParser(), FormParser(), MultiPartParser())
+ parser_context = {
+ 'encoding': 'utf-8',
+ 'kwargs': {},
+ 'args': (),
+ }
+ request = Request(wsgi_request, parsers=parsers, parser_context=parser_context)
+ method = request.method
+ self.assertEqual(method, 'POST')
+ self.assertEqual(request._content_type, 'application/json')
+ self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}')
+ self.assertEqual(request._data, Empty)
+ self.assertEqual(request._files, Empty)
+
# def test_accessing_post_after_data_form(self):
# """
# Ensures request.POST can be accessed after request.DATA in
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 1f85a474..75d6e785 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -105,6 +105,17 @@ class ModelSerializerWithNestedSerializer(serializers.ModelSerializer):
model = Person
+class NestedSerializerWithRenamedField(serializers.Serializer):
+ renamed_info = serializers.Field(source='info')
+
+
+class ModelSerializerWithNestedSerializerWithRenamedField(serializers.ModelSerializer):
+ nested = NestedSerializerWithRenamedField(source='*')
+
+ class Meta:
+ model = Person
+
+
class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
"""
Testing for #652.
@@ -456,6 +467,20 @@ class ValidationTests(TestCase):
)
self.assertEqual(serializer.is_valid(), True)
+ def test_writable_star_source_with_inner_source_fields(self):
+ """
+ Tests that a serializer with source="*" correctly expands the
+ it's fields into the outer serializer even if they have their
+ own 'source' parameters.
+ """
+
+ serializer = ModelSerializerWithNestedSerializerWithRenamedField(data={
+ 'name': 'marko',
+ 'nested': {'renamed_info': 'hi'}},
+ )
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
+
class CustomValidationTests(TestCase):
class CommentSerializerWithFieldValidator(CommentSerializer):
@@ -558,6 +583,29 @@ class ModelValidationTests(TestCase):
self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
+ def test_foreign_key_is_null_with_partial(self):
+ """
+ Test ModelSerializer validation with partial=True
+
+ Specifically test that a null foreign key does not pass validation
+ """
+ album = Album(title='test')
+ album.save()
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+
+ photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
+ self.assertTrue(photo_serializer.is_valid())
+ photo = photo_serializer.save()
+
+ # Updating only the album (foreign key)
+ photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True)
+ self.assertFalse(photo_serializer.is_valid())
+ self.assertTrue('album' in photo_serializer.errors)
+ self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required'])
+
def test_foreign_key_with_partial(self):
"""
Test ModelSerializer validation with partial=True
@@ -1720,3 +1768,75 @@ class TestSerializerTransformMethods(TestCase):
'b_renamed': None,
}
)
+
+
+class DefaultTrueBooleanModel(models.Model):
+ cat = models.BooleanField(default=True)
+ dog = models.BooleanField(default=False)
+
+
+class SerializerDefaultTrueBoolean(TestCase):
+
+ def setUp(self):
+ super(SerializerDefaultTrueBoolean, self).setUp()
+
+ class DefaultTrueBooleanSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DefaultTrueBooleanModel
+ fields = ('cat', 'dog')
+
+ self.default_true_boolean_serializer = DefaultTrueBooleanSerializer
+
+ def test_enabled_as_false(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': False,
+ 'dog': False})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], False)
+ self.assertEqual(serializer.data['dog'], False)
+
+ def test_enabled_as_true(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': True,
+ 'dog': True})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], True)
+ self.assertEqual(serializer.data['dog'], True)
+
+ def test_enabled_partial(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': False},
+ partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], False)
+ self.assertEqual(serializer.data['dog'], False)
+
+
+class BoolenFieldTypeTest(TestCase):
+ '''
+ Ensure the various Boolean based model fields are rendered as the proper
+ field type
+
+ '''
+
+ def setUp(self):
+ '''
+ Setup an ActionItemSerializer for BooleanTesting
+ '''
+ data = {
+ 'title': 'b' * 201,
+ }
+ self.serializer = ActionItemSerializer(data=data)
+
+ def test_booleanfield_type(self):
+ '''
+ Test that BooleanField is infered from models.BooleanField
+ '''
+ bfield = self.serializer.get_fields()['done']
+ self.assertEqual(type(bfield), fields.BooleanField)
+
+ def test_nullbooleanfield_type(self):
+ '''
+ Test that BooleanField is infered from models.NullBooleanField
+
+ https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8
+ '''
+ bfield = self.serializer.get_fields()['started']
+ self.assertEqual(type(bfield), fields.BooleanField)
diff --git a/rest_framework/tests/test_serializer_import.py b/rest_framework/tests/test_serializer_import.py
new file mode 100644
index 00000000..9f30a7ff
--- /dev/null
+++ b/rest_framework/tests/test_serializer_import.py
@@ -0,0 +1,19 @@
+from django.test import TestCase
+
+from rest_framework import serializers
+from rest_framework.tests.accounts.serializers import AccountSerializer
+
+
+class ImportingModelSerializerTests(TestCase):
+ """
+ In some situations like, GH #1225, it is possible, especially in
+ testing, to import a serializer who's related models have not yet
+ been resolved by Django. `AccountSerializer` is an example of such
+ a serializer (imported at the top of this file).
+ """
+ def test_import_model_serializer(self):
+ """
+ The serializer at the top of this file should have been
+ imported successfully, and we should be able to instantiate it.
+ """
+ self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer)
diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py
index 029f8bff..6d69ffbd 100644
--- a/rest_framework/tests/test_serializer_nested.py
+++ b/rest_framework/tests/test_serializer_nested.py
@@ -6,6 +6,7 @@ Doesn't cover model serializers.
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework import serializers
+from . import models
class WritableNestedSerializerBasicTests(TestCase):
@@ -311,3 +312,36 @@ class ForeignKeyNestedSerializerUpdateTests(TestCase):
serializer = self.AlbumSerializer(instance=original, data=data)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected)
+
+
+class NestedModelSerializerUpdateTests(TestCase):
+ def test_second_nested_level(self):
+ john = models.Person.objects.create(name="john")
+
+ post = john.blogpost_set.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = models.BlogPostComment
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
+ class Meta:
+ model = models.BlogPost
+ fields = ('id', 'title', 'comments')
+
+ class PersonSerializer(serializers.ModelSerializer):
+ posts = BlogPostSerializer(many=True, source='blogpost_set')
+ class Meta:
+ model = models.Person
+ fields = ('id', 'name', 'age', 'posts')
+
+ serialize = PersonSerializer(instance=john)
+ deserialize = PersonSerializer(data=serialize.data, instance=john)
+ self.assertTrue(deserialize.is_valid())
+
+ result = deserialize.object
+ result.save()
+ self.assertEqual(result.id, john.id)
diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py
new file mode 100644
index 00000000..082a400c
--- /dev/null
+++ b/rest_framework/tests/test_serializers.py
@@ -0,0 +1,28 @@
+from django.db import models
+from django.test import TestCase
+
+from rest_framework.serializers import _resolve_model
+from rest_framework.tests.models import BasicModel
+
+
+class ResolveModelTests(TestCase):
+ """
+ `_resolve_model` should return a Django model class given the
+ provided argument is a Django model class itself, or a properly
+ formatted string representation of one.
+ """
+ def test_resolve_django_model(self):
+ resolved_model = _resolve_model(BasicModel)
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_string_representation(self):
+ resolved_model = _resolve_model('tests.BasicModel')
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_non_django_model(self):
+ with self.assertRaises(ValueError):
+ _resolve_model(TestCase)
+
+ def test_resolve_improper_string_representation(self):
+ with self.assertRaises(ValueError):
+ _resolve_model('BasicModel')
diff --git a/rest_framework/tests/test_status.py b/rest_framework/tests/test_status.py
new file mode 100644
index 00000000..7b1bdae3
--- /dev/null
+++ b/rest_framework/tests/test_status.py
@@ -0,0 +1,33 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.status import (
+ is_informational, is_success, is_redirect, is_client_error, is_server_error
+)
+
+
+class TestStatus(TestCase):
+ def test_status_categories(self):
+ self.assertFalse(is_informational(99))
+ self.assertTrue(is_informational(100))
+ self.assertTrue(is_informational(199))
+ self.assertFalse(is_informational(200))
+
+ self.assertFalse(is_success(199))
+ self.assertTrue(is_success(200))
+ self.assertTrue(is_success(299))
+ self.assertFalse(is_success(300))
+
+ self.assertFalse(is_redirect(299))
+ self.assertTrue(is_redirect(300))
+ self.assertTrue(is_redirect(399))
+ self.assertFalse(is_redirect(400))
+
+ self.assertFalse(is_client_error(399))
+ self.assertTrue(is_client_error(400))
+ self.assertTrue(is_client_error(499))
+ self.assertFalse(is_client_error(500))
+
+ self.assertFalse(is_server_error(499))
+ self.assertTrue(is_server_error(500))
+ self.assertTrue(is_server_error(599))
+ self.assertFalse(is_server_error(600)) \ No newline at end of file
diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py
new file mode 100644
index 00000000..609a9e08
--- /dev/null
+++ b/rest_framework/tests/test_templatetags.py
@@ -0,0 +1,19 @@
+# encoding: utf-8
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+from rest_framework.templatetags.rest_framework import add_query_param
+
+factory = APIRequestFactory()
+
+
+class TemplateTagTests(TestCase):
+
+ def test_add_query_param_with_non_latin_charactor(self):
+ # Ensure we don't double-escape non-latin characters
+ # that are present in the querystring.
+ # See #1314.
+ request = factory.get("/", {'q': '查询'})
+ json_url = add_query_param(request, "format", "json")
+ self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url)
+ self.assertIn("format=json", json_url)
diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py
index ebfdff9c..124c874d 100644
--- a/rest_framework/tests/test_validation.py
+++ b/rest_framework/tests/test_validation.py
@@ -47,12 +47,18 @@ class ShouldValidateModel(models.Model):
class ShouldValidateModelSerializer(serializers.ModelSerializer):
renamed = serializers.CharField(source='should_validate_field', required=False)
+ def validate_renamed(self, attrs, source):
+ value = attrs[source]
+ if len(value) < 3:
+ raise serializers.ValidationError('Minimum 3 characters.')
+ return attrs
+
class Meta:
model = ShouldValidateModel
fields = ('renamed',)
-class TestPreSaveValidationExclusions(TestCase):
+class TestPreSaveValidationExclusionsSerializer(TestCase):
def test_renamed_fields_are_model_validated(self):
"""
Ensure fields with 'source' applied do get still get model validation.
@@ -61,6 +67,19 @@ class TestPreSaveValidationExclusions(TestCase):
# does not have `blank=True`, so this serializer should not validate.
serializer = ShouldValidateModelSerializer(data={'renamed': ''})
self.assertEqual(serializer.is_valid(), False)
+ self.assertIn('renamed', serializer.errors)
+ self.assertNotIn('should_validate_field', serializer.errors)
+
+
+class TestCustomValidationMethods(TestCase):
+ def test_custom_validation_method_is_executed(self):
+ serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'})
+ self.assertFalse(serializer.is_valid())
+ self.assertIn('renamed', serializer.errors)
+
+ def test_custom_validation_method_passing(self):
+ serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'})
+ self.assertTrue(serializer.is_valid())
class ValidationSerializer(serializers.Serializer):
diff --git a/rest_framework/tests/test_write_only_fields.py b/rest_framework/tests/test_write_only_fields.py
new file mode 100644
index 00000000..aabb18d6
--- /dev/null
+++ b/rest_framework/tests/test_write_only_fields.py
@@ -0,0 +1,42 @@
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class ExampleModel(models.Model):
+ email = models.EmailField(max_length=100)
+ password = models.CharField(max_length=100)
+
+
+class WriteOnlyFieldTests(TestCase):
+ def test_write_only_fields(self):
+ class ExampleSerializer(serializers.Serializer):
+ email = serializers.EmailField()
+ password = serializers.CharField(write_only=True)
+
+ data = {
+ 'email': 'foo@example.com',
+ 'password': '123'
+ }
+ serializer = ExampleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.object, data)
+ self.assertEquals(serializer.data, {'email': 'foo@example.com'})
+
+ def test_write_only_fields_meta(self):
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ExampleModel
+ fields = ('email', 'password')
+ write_only_fields = ('password',)
+
+ data = {
+ 'email': 'foo@example.com',
+ 'password': '123'
+ }
+ serializer = ExampleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, ExampleModel))
+ self.assertEquals(serializer.object.email, data['email'])
+ self.assertEquals(serializer.object.password, data['password'])
+ self.assertEquals(serializer.data, {'email': 'foo@example.com'})
diff --git a/rest_framework/tests/users/__init__.py b/rest_framework/tests/users/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/rest_framework/tests/users/__init__.py
diff --git a/rest_framework/tests/users/models.py b/rest_framework/tests/users/models.py
new file mode 100644
index 00000000..128bac90
--- /dev/null
+++ b/rest_framework/tests/users/models.py
@@ -0,0 +1,6 @@
+from django.db import models
+
+
+class User(models.Model):
+ account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users')
+ active_record = models.ForeignKey('records.Record', blank=True, null=True)
diff --git a/rest_framework/tests/users/serializers.py b/rest_framework/tests/users/serializers.py
new file mode 100644
index 00000000..da496554
--- /dev/null
+++ b/rest_framework/tests/users/serializers.py
@@ -0,0 +1,8 @@
+from rest_framework import serializers
+
+from rest_framework.tests.users.models import User
+
+
+class UserSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = User
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index d9143bb4..0ff137b0 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -57,6 +57,6 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
allowed_pattern = '(%s)' % '|'.join(allowed)
suffix_pattern = r'\.(?P<%s>%s)$' % (suffix_kwarg, allowed_pattern)
else:
- suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg
+ suffix_pattern = r'\.(?P<%s>[a-z0-9]+)$' % suffix_kwarg
return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 35ad206b..e5fa4194 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -2,6 +2,7 @@
Helper classes for parsers.
"""
from __future__ import unicode_literals
+from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
from rest_framework.compat import timezone, force_text
@@ -42,8 +43,15 @@ class JSONEncoder(json.JSONEncoder):
return str(o.total_seconds())
elif isinstance(o, decimal.Decimal):
return str(o)
+ elif isinstance(o, QuerySet):
+ return list(o)
elif hasattr(o, 'tolist'):
return o.tolist()
+ elif hasattr(o, '__getitem__'):
+ try:
+ return dict(o)
+ except:
+ pass
elif hasattr(o, '__iter__'):
return [i for i in o]
return super(JSONEncoder, self).default(o)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 853e6461..e863af6d 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -154,8 +154,8 @@ class APIView(View):
Returns a dict that is passed through to Parser.parse(),
as the `parser_context` keyword argument.
"""
- # Note: Additionally `request` will also be added to the context
- # by the Request object.
+ # Note: Additionally `request` and `encoding` will also be added
+ # to the context by the Request object.
return {
'view': self,
'args': getattr(self, 'args', ()),
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
index d91323f2..7eb29f99 100644
--- a/rest_framework/viewsets.py
+++ b/rest_framework/viewsets.py
@@ -9,7 +9,7 @@ Actions are only bound to methods at the point of instantiating the views.
user_detail = UserViewSet.as_view({'get': 'retrieve'})
Typically, rather than instantiate views from viewsets directly, you'll
-regsiter the viewset with a router and let the URL conf be determined
+register the viewset with a router and let the URL conf be determined
automatically.
router = DefaultRouter()