aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authentication.py83
-rw-r--r--rest_framework/authtoken/views.py3
-rw-r--r--rest_framework/compat.py6
-rw-r--r--rest_framework/decorators.py9
-rw-r--r--rest_framework/exceptions.py16
-rw-r--r--rest_framework/fields.py5
-rw-r--r--rest_framework/generics.py31
-rw-r--r--rest_framework/mixins.py19
-rw-r--r--rest_framework/pagination.py26
-rw-r--r--rest_framework/parsers.py2
-rw-r--r--rest_framework/relations.py112
-rw-r--r--rest_framework/renderers.py2
-rw-r--r--rest_framework/request.py29
-rw-r--r--rest_framework/serializers.py74
-rw-r--r--rest_framework/settings.py5
-rw-r--r--rest_framework/templates/rest_framework/base.html2
-rw-r--r--rest_framework/templates/rest_framework/login.html8
-rw-r--r--rest_framework/templatetags/rest_framework.py2
-rw-r--r--rest_framework/tests/authentication.py48
-rw-r--r--rest_framework/tests/decorators.py39
-rw-r--r--rest_framework/tests/extras/__init__.py0
-rw-r--r--rest_framework/tests/extras/bad_import.py1
-rw-r--r--rest_framework/tests/fields.py49
-rw-r--r--rest_framework/tests/files.py14
-rw-r--r--rest_framework/tests/genericrelations.py96
-rw-r--r--rest_framework/tests/generics.py19
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py2
-rw-r--r--rest_framework/tests/models.py61
-rw-r--r--rest_framework/tests/pagination.py47
-rw-r--r--rest_framework/tests/relations.py47
-rw-r--r--rest_framework/tests/relations_hyperlink.py69
-rw-r--r--rest_framework/tests/relations_nested.py53
-rw-r--r--rest_framework/tests/relations_pk.py65
-rw-r--r--rest_framework/tests/relations_slug.py257
-rw-r--r--rest_framework/tests/request.py2
-rw-r--r--rest_framework/tests/serializer.py145
-rw-r--r--rest_framework/tests/settings.py21
-rw-r--r--rest_framework/tests/urlpatterns.py78
-rw-r--r--rest_framework/tests/utils.py27
-rw-r--r--rest_framework/tests/views.py4
-rw-r--r--rest_framework/urlpatterns.py45
-rw-r--r--rest_framework/utils/encoders.py6
-rw-r--r--rest_framework/views.py21
44 files changed, 1351 insertions, 301 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 151ba832..f9882c57 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '2.1.14'
+__version__ = '2.1.17'
VERSION = __version__ # synonym
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index c50bf944..76ee4bd6 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -23,34 +23,47 @@ class BaseAuthentication(object):
"""
raise NotImplementedError(".authenticate() must be overridden.")
+ def authenticate_header(self, request):
+ """
+ Return a string to be used as the value of the `WWW-Authenticate`
+ header in a `401 Unauthenticated` response, or `None` if the
+ authentication scheme should return `403 Permission Denied` responses.
+ """
+ pass
+
class BasicAuthentication(BaseAuthentication):
"""
HTTP Basic authentication against username/password.
"""
+ www_authenticate_realm = 'api'
def authenticate(self, request):
"""
Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`.
"""
- if 'HTTP_AUTHORIZATION' in request.META:
- auth = request.META['HTTP_AUTHORIZATION'].split()
- if len(auth) == 2 and auth[0].lower() == "basic":
- try:
- encoding = api_settings.HTTP_HEADER_ENCODING
- b = base64.b64decode(auth[1].encode(encoding))
- auth_parts = b.decode(encoding).partition(':')
- except TypeError:
- return None
-
- try:
- userid = smart_text(auth_parts[0])
- password = smart_text(auth_parts[2])
- except DjangoUnicodeDecodeError:
- return None
-
- return self.authenticate_credentials(userid, password)
+ auth = request.META.get('HTTP_AUTHORIZATION', '').split()
+
+ if not auth or auth[0].lower() != "basic":
+ return None
+
+ if len(auth) != 2:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ encoding = api_settings.HTTP_HEADER_ENCODING
+ try:
+ auth_parts = base64.b64decode(auth[1].encode(encoding)).partition(':')
+ except TypeError:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ try:
+ userid = smart_text(auth_parts[0])
+ password = smart_text(auth_parts[2])
+ except DjangoUnicodeDecodeError:
+ raise exceptions.AuthenticationFailed('Invalid basic header')
+
+ return self.authenticate_credentials(userid, password)
def authenticate_credentials(self, userid, password):
"""
@@ -59,6 +72,10 @@ class BasicAuthentication(BaseAuthentication):
user = authenticate(username=userid, password=password)
if user is not None and user.is_active:
return (user, None)
+ raise exceptions.AuthenticationFailed('Invalid username/password')
+
+ def authenticate_header(self, request):
+ return 'Basic realm="%s"' % self.www_authenticate_realm
class SessionAuthentication(BaseAuthentication):
@@ -78,7 +95,7 @@ class SessionAuthentication(BaseAuthentication):
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
- return
+ return None
# Enforce CSRF validation for session based authentication.
class CSRFCheck(CsrfViewMiddleware):
@@ -89,7 +106,7 @@ class SessionAuthentication(BaseAuthentication):
reason = CSRFCheck().process_view(http_request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
- raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
+ raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)
# CSRF passed with authenticated user
return (user, None)
@@ -116,14 +133,26 @@ class TokenAuthentication(BaseAuthentication):
def authenticate(self, request):
auth = request.META.get('HTTP_AUTHORIZATION', '').split()
- if len(auth) == 2 and auth[0].lower() == "token":
- key = auth[1]
- try:
- token = self.model.objects.get(key=key)
- except self.model.DoesNotExist:
- return None
+ if not auth or auth[0].lower() != "token":
+ return None
+
+ if len(auth) != 2:
+ raise exceptions.AuthenticationFailed('Invalid token header')
+
+ return self.authenticate_credentials(auth[1])
+
+ def authenticate_credentials(self, key):
+ try:
+ token = self.model.objects.get(key=key)
+ except self.model.DoesNotExist:
+ raise exceptions.AuthenticationFailed('Invalid token')
+
+ if token.user.is_active:
+ return (token.user, token)
+ raise exceptions.AuthenticationFailed('User inactive or deleted')
+
+ def authenticate_header(self, request):
+ return 'Token'
- if token.user.is_active:
- return (token.user, token)
# TODO: OAuthAuthentication
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index d318c723..7c03cb76 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -12,10 +12,11 @@ class ObtainAuthToken(APIView):
permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
renderer_classes = (renderers.JSONRenderer,)
+ serializer_class = AuthTokenSerializer
model = Token
def post(self, request):
- serializer = AuthTokenSerializer(data=request.DATA)
+ serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user'])
return Response({'token': token.key})
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 5924cd6d..ef11b85b 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -126,6 +126,12 @@ else:
update_wrapper(view, cls.dispatch, assigned=())
return view
+# Taken from @markotibold's attempt at supporting PATCH.
+# https://github.com/markotibold/django-rest-framework/tree/patch
+http_method_names = set(View.http_method_names)
+http_method_names.add('patch')
+View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django
+
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4):
from django.middleware.csrf import CsrfViewMiddleware
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 1b710a03..7a4103e1 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,4 +1,5 @@
from rest_framework.views import APIView
+import types
def api_view(http_method_names):
@@ -23,6 +24,14 @@ def api_view(http_method_names):
# pass
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
+ # api_view applied without (method_names)
+ assert not(isinstance(http_method_names, types.FunctionType)), \
+ '@api_view missing list of allowed HTTP methods'
+
+ # api_view applied with eg. string instead of list of strings
+ assert isinstance(http_method_names, (list, tuple)), \
+ '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__
+
allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 89479deb..d635351c 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -23,6 +23,22 @@ class ParseError(APIException):
self.detail = detail or self.default_detail
+class AuthenticationFailed(APIException):
+ status_code = status.HTTP_401_UNAUTHORIZED
+ default_detail = 'Incorrect authentication credentials.'
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
+
+
+class NotAuthenticated(APIException):
+ status_code = status.HTTP_401_UNAUTHORIZED
+ default_detail = 'Authentication credentials were not provided.'
+
+ def __init__(self, detail=None):
+ self.detail = detail or self.default_detail
+
+
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = 'You do not have permission to perform this action.'
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index adea5bf5..a66e1d7c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,4 +1,3 @@
-
from __future__ import unicode_literals
import copy
@@ -185,11 +184,13 @@ class WritableField(Field):
try:
if self._use_files:
+ files = files or {}
native = files[field_name]
else:
native = data[field_name]
except KeyError:
- if self.default is not None:
+ if self.default is not None and not self.root.partial:
+ # Note: partial updates shouldn't set defaults
native = self.default
else:
if self.required:
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index dd8dfcf8..19f2b704 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -47,14 +47,16 @@ class GenericAPIView(views.APIView):
return serializer_class
- def get_serializer(self, instance=None, data=None, files=None):
+ def get_serializer(self, instance=None, data=None,
+ files=None, partial=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(instance, data=data, files=files, context=context)
+ return serializer_class(instance, data=data, files=files,
+ partial=partial, context=context)
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
@@ -171,6 +173,10 @@ class UpdateAPIView(mixins.UpdateModelMixin,
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
@@ -185,6 +191,23 @@ class ListCreateAPIView(mixins.ListModelMixin,
return self.create(request, *args, **kwargs)
+class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ SingleObjectAPIView):
+ """
+ Concrete view for retrieving, updating a model instance.
+ """
+ def get(self, request, *args, **kwargs):
+ return self.retrieve(request, *args, **kwargs)
+
+ def put(self, request, *args, **kwargs):
+ return self.update(request, *args, **kwargs)
+
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
+
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
SingleObjectAPIView):
@@ -211,5 +234,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
+ def patch(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 503376ce..acaf8a71 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -18,11 +18,14 @@ class CreateModelMixin(object):
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
+
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
headers = self.get_success_headers(serializer.data)
- return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+ return Response(serializer.data, status=status.HTTP_201_CREATED,
+ headers=headers)
+
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_success_headers(self, data):
@@ -84,20 +87,21 @@ class UpdateModelMixin(object):
Should be mixed in with `SingleObjectBaseView`.
"""
def update(self, request, *args, **kwargs):
+ partial = kwargs.pop('partial', False)
try:
self.object = self.get_object()
- created = False
+ success_status_code = status.HTTP_200_OK
except Http404:
self.object = None
- created = True
+ success_status_code = status.HTTP_201_CREATED
- serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES)
+ serializer = self.get_serializer(self.object, data=request.DATA,
+ files=request.FILES, partial=partial)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
- status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
- return Response(serializer.data, status=status_code)
+ return Response(serializer.data, status=success_status_code)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -117,7 +121,8 @@ class UpdateModelMixin(object):
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
- obj.full_clean()
+ if hasattr(obj, 'full_clean'):
+ obj.full_clean()
class DestroyModelMixin(object):
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index d241ade7..92d41e0e 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -34,6 +34,17 @@ class PreviousPageField(serializers.Field):
return replace_query_param(url, self.page_field, page)
+class DefaultObjectSerializer(serializers.Field):
+ """
+ If no object serializer is specified, then this serializer will be applied
+ as the default.
+ """
+
+ def __init__(self, source=None, context=None):
+ # Note: Swallow context kwarg - only required for eg. ModelSerializer.
+ super(DefaultObjectSerializer, self).__init__(source=source)
+
+
class PaginationSerializerOptions(serializers.SerializerOptions):
"""
An object that stores the options that may be provided to a
@@ -44,7 +55,7 @@ class PaginationSerializerOptions(serializers.SerializerOptions):
def __init__(self, meta):
super(PaginationSerializerOptions, self).__init__(meta)
self.object_serializer_class = getattr(meta, 'object_serializer_class',
- serializers.Field)
+ DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer):
@@ -62,14 +73,13 @@ class BasePaginationSerializer(serializers.Serializer):
super(BasePaginationSerializer, self).__init__(*args, **kwargs)
results_field = self.results_field
object_serializer = self.opts.object_serializer_class
- self.fields[results_field] = object_serializer(source='object_list')
- def to_native(self, obj):
- """
- Prevent default behaviour of iterating over elements, and serializing
- each in turn.
- """
- return self.convert_object(obj)
+ if 'context' in kwargs:
+ context_kwarg = {'context': kwargs['context']}
+ else:
+ context_kwarg = {}
+
+ self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)
class PaginationSerializer(BasePaginationSerializer):
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 7c01006a..4a2b34a5 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -8,12 +8,12 @@ on the request, such as form content or json encoded data.
from django.http import QueryDict
from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError
-from django.utils import simplejson as json
from rest_framework.compat import yaml, ETParseError
from rest_framework.exceptions import ParseError
from rest_framework.compat import six
from xml.etree import ElementTree as ET
from xml.parsers.expat import ExpatError
+import json
import datetime
import decimal
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index b7a6e0c1..c4f854ef 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -6,6 +6,7 @@ from django.core.urlresolvers import resolve, get_script_prefix
from django import forms
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
+from django.utils.translation import ugettext_lazy as _
from rest_framework.fields import Field, WritableField
from rest_framework.reverse import reverse
from rest_framework.compat import urlparse
@@ -103,7 +104,13 @@ class RelatedField(WritableField):
### Regular serializer stuff...
def field_to_native(self, obj, field_name):
- value = getattr(obj, self.source or field_name)
+ try:
+ value = getattr(obj, self.source or field_name)
+ except ObjectDoesNotExist:
+ return None
+
+ if value is None:
+ return None
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
@@ -144,7 +151,7 @@ class ManyRelatedMixin(object):
value = data.getlist(self.source or field_name)
except:
# Non-form data
- value = data.get(self.source or field_name)
+ value = data.get(self.source or field_name, [])
else:
if value == ['']:
value = []
@@ -171,6 +178,11 @@ class PrimaryKeyRelatedField(RelatedField):
default_read_only = False
form_field_class = forms.ChoiceField
+ default_error_messages = {
+ 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
+ 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
+ }
+
# TODO: Remove these field hacks...
def prepare_value(self, obj):
return self.to_native(obj.pk)
@@ -196,7 +208,11 @@ class PrimaryKeyRelatedField(RelatedField):
try:
return self.queryset.get(pk=data)
except ObjectDoesNotExist:
- msg = "Invalid pk '%s' - object does not exist." % smart_text(data)
+ msg = self.error_messages['does_not_exist'] % smart_text(data)
+ raise ValidationError(msg)
+ except (TypeError, ValueError):
+ received = type(data).__name__
+ msg = self.error_messages['incorrect_type'] % received
raise ValidationError(msg)
def field_to_native(self, obj, field_name):
@@ -205,7 +221,10 @@ class PrimaryKeyRelatedField(RelatedField):
pk = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedObject (reverse relationship)
- obj = getattr(obj, self.source or field_name)
+ try:
+ obj = getattr(obj, self.source or field_name)
+ except ObjectDoesNotExist:
+ return None
return self.to_native(obj.pk)
# Forward relationship
return self.to_native(pk)
@@ -218,6 +237,11 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
default_read_only = False
form_field_class = forms.MultipleChoiceField
+ default_error_messages = {
+ 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
+ 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
+ }
+
def prepare_value(self, obj):
return self.to_native(obj.pk)
@@ -252,7 +276,11 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
try:
return self.queryset.get(pk=data)
except ObjectDoesNotExist:
- msg = "Invalid pk '%s' - object does not exist." % smart_text(data)
+ msg = self.error_messages['does_not_exist'] % smart_text(data)
+ raise ValidationError(msg)
+ except (TypeError, ValueError):
+ received = type(data).__name__
+ msg = self.error_messages['incorrect_type'] % received
raise ValidationError(msg)
### Slug relationships
@@ -262,6 +290,11 @@ class SlugRelatedField(RelatedField):
default_read_only = False
form_field_class = forms.ChoiceField
+ default_error_messages = {
+ 'does_not_exist': _("Object with %s=%s does not exist."),
+ 'invalid': _('Invalid value.'),
+ }
+
def __init__(self, *args, **kwargs):
self.slug_field = kwargs.pop('slug_field', None)
assert self.slug_field, 'slug_field is required'
@@ -277,8 +310,11 @@ class SlugRelatedField(RelatedField):
try:
return self.queryset.get(**{self.slug_field: data})
except ObjectDoesNotExist:
- raise ValidationError('Object with %s=%s does not exist.' %
+ raise ValidationError(self.error_messages['does_not_exist'] %
(self.slug_field, unicode(data)))
+ except (TypeError, ValueError):
+ msg = self.error_messages['invalid']
+ raise ValidationError(msg)
class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
@@ -297,6 +333,14 @@ class HyperlinkedRelatedField(RelatedField):
default_read_only = False
form_field_class = forms.ChoiceField
+ default_error_messages = {
+ 'no_match': _('Invalid hyperlink - No URL match'),
+ 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
+ 'configuration_error': _('Invalid hyperlink due to configuration error'),
+ 'does_not_exist': _("Invalid hyperlink - object does not exist."),
+ 'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
+ }
+
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
@@ -333,21 +377,21 @@ class HyperlinkedRelatedField(RelatedField):
slug = getattr(obj, self.slug_field, None)
if not slug:
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+ raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
kwargs = {self.slug_url_kwarg: slug}
try:
- return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
except:
pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try:
- return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
except:
pass
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+ raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
def from_native(self, value):
# Convert URL -> model instance pk
@@ -355,7 +399,13 @@ class HyperlinkedRelatedField(RelatedField):
if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
- if value.startswith('http:') or value.startswith('https:'):
+ try:
+ http_prefix = value.startswith('http:') or value.startswith('https:')
+ except AttributeError:
+ msg = self.error_messages['incorrect_type']
+ raise ValidationError(msg % type(value).__name__)
+
+ if http_prefix:
# If needed convert absolute URLs to relative path
value = urlparse.urlparse(value).path
prefix = get_script_prefix()
@@ -365,10 +415,10 @@ class HyperlinkedRelatedField(RelatedField):
try:
match = resolve(value)
except:
- raise ValidationError('Invalid hyperlink - No URL match')
+ raise ValidationError(self.error_messages['no_match'])
- if match.url_name != self.view_name:
- raise ValidationError('Invalid hyperlink - Incorrect URL match')
+ if match.view_name != self.view_name:
+ raise ValidationError(self.error_messages['incorrect_match'])
pk = match.kwargs.get(self.pk_url_kwarg, None)
slug = match.kwargs.get(self.slug_url_kwarg, None)
@@ -380,14 +430,18 @@ class HyperlinkedRelatedField(RelatedField):
elif slug is not None:
slug_field = self.get_slug_field()
queryset = self.queryset.filter(**{slug_field: slug})
- # If none of those are defined, it's an error.
+ # If none of those are defined, it's probably a configuation error.
else:
- raise ValidationError('Invalid hyperlink')
+ raise ValidationError(self.error_messages['configuration_error'])
try:
obj = queryset.get()
except ObjectDoesNotExist:
- raise ValidationError('Invalid hyperlink - object does not exist.')
+ raise ValidationError(self.error_messages['does_not_exist'])
+ except (TypeError, ValueError):
+ msg = self.error_messages['incorrect_type']
+ raise ValidationError(msg % type(value).__name__)
+
return obj
@@ -410,6 +464,7 @@ class HyperlinkedIdentityField(Field):
# TODO: Make view_name mandatory, and have the
# HyperlinkedModelSerializer set it on-the-fly
self.view_name = kwargs.pop('view_name', None)
+ # Optionally the format of the target hyperlink may be specified
self.format = kwargs.pop('format', None)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
@@ -421,9 +476,22 @@ class HyperlinkedIdentityField(Field):
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
+ format = self.context.get('format', None)
view_name = self.view_name or self.parent.opts.view_name
kwargs = {self.pk_url_kwarg: obj.pk}
+
+ # By default use whatever format is given for the current context
+ # unless the target is a different type to the source.
+ #
+ # Eg. Consider a HyperlinkedIdentityField pointing from a json
+ # representation to an html property of that representation...
+ #
+ # '/snippets/1/' should link to '/snippets/1/highlight/'
+ # ...but...
+ # '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
+ if format and self.format and self.format != format:
+ format = self.format
+
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except:
@@ -432,18 +500,18 @@ class HyperlinkedIdentityField(Field):
slug = getattr(obj, self.slug_field, None)
if not slug:
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+ raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
kwargs = {self.slug_url_kwarg: slug}
try:
- return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
except:
pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try:
- return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
except:
pass
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+ raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 54930167..b3ee0690 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -10,10 +10,10 @@ from __future__ import unicode_literals
import copy
import string
+import json
from django import forms
from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template
-from django.utils import simplejson as json
from rest_framework.compat import yaml
from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 048a1c41..23e1da87 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -86,6 +86,7 @@ class Request(object):
self._method = Empty
self._content_type = Empty
self._stream = Empty
+ self._authenticator = None
if self.parser_context is None:
self.parser_context = {}
@@ -166,7 +167,7 @@ class Request(object):
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
- self._user, self._auth = self._authenticate()
+ self._authenticator, self._user, self._auth = self._authenticate()
return self._user
@user.setter
@@ -185,7 +186,7 @@ class Request(object):
request, such as an authentication token.
"""
if not hasattr(self, '_auth'):
- self._user, self._auth = self._authenticate()
+ self._authenticator, self._user, self._auth = self._authenticate()
return self._auth
@auth.setter
@@ -196,6 +197,14 @@ class Request(object):
"""
self._auth = value
+ @property
+ def successful_authenticator(self):
+ """
+ Return the instance of the authentication instance class that was used
+ to authenticate the request, or `None`.
+ """
+ return self._authenticator
+
def _load_data_and_files(self):
"""
Parses the request content into self.DATA and self.FILES.
@@ -299,21 +308,23 @@ class Request(object):
def _authenticate(self):
"""
- Attempt to authenticate the request using each authentication instance in turn.
- Returns a two-tuple of (user, authtoken).
+ Attempt to authenticate the request using each authentication instance
+ in turn.
+ Returns a three-tuple of (authenticator, user, authtoken).
"""
for authenticator in self.authenticators:
user_auth_tuple = authenticator.authenticate(self)
if not user_auth_tuple is None:
- return user_auth_tuple
+ user, auth = user_auth_tuple
+ return (authenticator, user, auth)
return self._not_authenticated()
def _not_authenticated(self):
"""
- Return a two-tuple of (user, authtoken), representing an
- unauthenticated request.
+ Return a three-tuple of (authenticator, user, authtoken), representing
+ an unauthenticated request.
- By default this will be (AnonymousUser, None).
+ By default this will be (None, AnonymousUser, None).
"""
if api_settings.UNAUTHENTICATED_USER:
user = api_settings.UNAUTHENTICATED_USER()
@@ -325,7 +336,7 @@ class Request(object):
else:
auth = None
- return (user, auth)
+ return (None, user, auth)
def __getattr__(self, attr):
"""
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 663f166b..3d3bcb3c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -2,6 +2,7 @@ import copy
import datetime
import types
from decimal import Decimal
+from django.core.paginator import Page
from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
@@ -209,6 +210,11 @@ class BaseSerializer(Field):
Converts a dictionary of data into a dictionary of deserialized fields.
"""
reverted_data = {}
+
+ if data is not None and not isinstance(data, dict):
+ self._errors['non_field_errors'] = [u'Invalid data']
+ return None
+
for field_name, field in self.fields.items():
field.initialize(parent=self, field_name=field_name)
try:
@@ -223,6 +229,8 @@ class BaseSerializer(Field):
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
for field_name, field in self.fields.items():
+ if field_name in self._errors:
+ continue
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
@@ -267,7 +275,11 @@ class BaseSerializer(Field):
"""
Serialize objects -> primitives.
"""
- if hasattr(obj, '__iter__'):
+ # Note: At the moment we have an ugly hack to determine if we should
+ # walk over iterables. At some point, serializers will require an
+ # explicit `many=True` in order to iterate over a set, and this hack
+ # will disappear.
+ if hasattr(obj, '__iter__') and not isinstance(obj, Page):
return [self.convert_object(item) for item in obj]
return self.convert_object(obj)
@@ -277,7 +289,7 @@ class BaseSerializer(Field):
"""
if hasattr(data, '__iter__') and not isinstance(data, dict):
# TODO: error data when deserializing lists
- return (self.from_native(item) for item in data)
+ return [self.from_native(item, None) for item in data]
self._errors = {}
if data is not None or files is not None:
@@ -294,15 +306,21 @@ class BaseSerializer(Field):
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
- if self.source:
- for component in self.source.split('.'):
- obj = getattr(obj, component)
+ if self.source == '*':
+ return self.to_native(obj)
+
+ try:
+ if self.source:
+ for component in self.source.split('.'):
+ obj = getattr(obj, component)
+ if is_simple_callable(obj):
+ obj = obj()
+ else:
+ obj = getattr(obj, field_name)
if is_simple_callable(obj):
obj = obj()
- else:
- obj = getattr(obj, field_name)
- if is_simple_callable(obj):
- obj = value()
+ except ObjectDoesNotExist:
+ return None
# If the object has an "all" method, assume it's a relationship
if is_simple_callable(getattr(obj, 'all', None)):
@@ -408,7 +426,7 @@ class ModelSerializer(Serializer):
"""
Returns a default instance of the pk field.
"""
- return Field()
+ return self.get_field(model_field)
def get_nested_field(self, model_field):
"""
@@ -426,7 +444,7 @@ class ModelSerializer(Serializer):
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = {
- 'null': model_field.null,
+ 'null': model_field.null or model_field.blank,
'queryset': model_field.rel.to._default_manager
}
@@ -445,11 +463,14 @@ class ModelSerializer(Serializer):
if model_field.null or model_field.blank:
kwargs['required'] = False
+ if isinstance(model_field, models.AutoField) or not model_field.editable:
+ kwargs['read_only'] = True
+
if model_field.has_default():
kwargs['required'] = False
kwargs['default'] = model_field.get_default()
- if model_field.__class__ == models.TextField:
+ if issubclass(model_field.__class__, models.TextField):
kwargs['widget'] = widgets.Textarea
# TODO: TypedChoiceField?
@@ -458,6 +479,7 @@ class ModelSerializer(Serializer):
return ChoiceField(**kwargs)
field_mapping = {
+ models.AutoField: IntegerField,
models.FloatField: FloatField,
models.IntegerField: IntegerField,
models.PositiveIntegerField: IntegerField,
@@ -492,6 +514,22 @@ class ModelSerializer(Serializer):
exclusions.remove(field_name)
return exclusions
+ def full_clean(self, instance):
+ """
+ Perform Django's full_clean, and populate the `errors` dictionary
+ if any validation errors occur.
+
+ Note that we don't perform this inside the `.restore_object()` method,
+ so that subclasses can override `.restore_object()`, and still get
+ the full_clean validation checking.
+ """
+ try:
+ instance.full_clean(exclude=self.get_validation_exclusions())
+ except ValidationError, err:
+ self._errors = err.message_dict
+ return None
+ return instance
+
def restore_object(self, attrs, instance=None):
"""
Restore the model instance.
@@ -531,13 +569,21 @@ class ModelSerializer(Serializer):
return instance
- def save(self, save_m2m=True):
+ def from_native(self, data, files):
+ """
+ Override the default method to also include model field validation.
+ """
+ instance = super(ModelSerializer, self).from_native(data, files)
+ if instance:
+ return self.full_clean(instance)
+
+ def save(self):
"""
Save the deserialized object and return it.
"""
self.object.save()
- if getattr(self, 'm2m_data', None) and save_m2m:
+ if getattr(self, 'm2m_data', None):
for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list)
self.m2m_data = {}
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 2358d188..13d03e62 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -119,9 +119,8 @@ def import_from_string(val, setting_name):
module_path, class_name = '.'.join(parts[:-1]), parts[-1]
module = importlib.import_module(module_path)
return getattr(module, class_name)
- except:
- raise
- msg = "Could not import '%s' for API setting '%s'" % (val, setting_name)
+ except ImportError as e:
+ msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
raise ImportError(msg)
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 42e49cb9..092bf2e4 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -112,7 +112,7 @@
<div class="request-info">
<pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>
- <div>
+ </div>
<div class="response-info">
<pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}
{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span>
diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html
index 6e2bd8d4..e10ce20f 100644
--- a/rest_framework/templates/rest_framework/login.html
+++ b/rest_framework/templates/rest_framework/login.html
@@ -25,14 +25,14 @@
<form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
{% csrf_token %}
<div id="div_id_username" class="clearfix control-group">
- <div class="controls" style="height: 30px">
- <Label class="span4" style="margin-top: 3px">Username:</label>
+ <div class="controls">
+ <Label class="span4">Username:</label>
<input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
</div>
</div>
<div id="div_id_password" class="clearfix control-group">
- <div class="controls" style="height: 30px">
- <Label class="span4" style="margin-top: 3px">Password:</label>
+ <div class="controls">
+ <Label class="span4">Password:</label>
<input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
</div>
</div>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index 4205e57c..cbafbe0e 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -27,7 +27,7 @@ register = template.Library()
# conflicts with this rest_framework template tag module.
try: # Django 1.5+
- from django.contrib.staticfiles.templatetags import StaticFilesNode
+ from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode
@register.tag('static')
def do_static(parser, token):
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index 8c0bfc47..ba2042cb 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -1,14 +1,13 @@
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import Client, TestCase
-from django.utils import simplejson as json
-
from rest_framework import permissions
from rest_framework.authtoken.models import Token
-from rest_framework.authentication import TokenAuthentication
+from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication
from rest_framework.compat import patterns
from rest_framework.views import APIView
+import json
import base64
@@ -21,10 +20,10 @@ class MockView(APIView):
def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-MockView.authentication_classes += (TokenAuthentication,)
-
urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
)
@@ -42,25 +41,26 @@ class BasicAuthTests(TestCase):
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
- auth = 'Basic ' + base64.encodestring(('%s:%s' % (self.username, self.password)).encode('iso-8859-1')).strip().decode('iso-8859-1')
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).encode('iso-8859-1').strip().decode('iso-8859-1')
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
- auth = 'Basic ' + base64.encodestring(('%s:%s' % (self.username, self.password)).encode('iso-8859-1')).strip().decode('iso-8859-1')
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).encode('iso-8859-1').strip().decode('iso-8859-1')
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, 401)
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 401)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
class SessionAuthTests(TestCase):
@@ -83,7 +83,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/', {'example': 'example'})
+ response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
def test_post_form_session_auth_passing(self):
@@ -91,7 +91,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.post('/', {'example': 'example'})
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_put_form_session_auth_passing(self):
@@ -99,14 +99,14 @@ class SessionAuthTests(TestCase):
Ensure PUTting form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.put('/', {'example': 'example'})
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
- response = self.csrf_client.post('/', {'example': 'example'})
+ response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
@@ -127,24 +127,24 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, 401)
def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 401)
def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index 8079c8cb..82f912e9 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -1,5 +1,4 @@
from django.test import TestCase
-from django.test.client import RequestFactory
from rest_framework import status
from rest_framework.response import Response
from rest_framework.renderers import JSONRenderer
@@ -17,6 +16,8 @@ from rest_framework.decorators import (
permission_classes,
)
+from rest_framework.tests.utils import RequestFactory
+
class DecoratorTestCase(TestCase):
@@ -27,13 +28,27 @@ class DecoratorTestCase(TestCase):
response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs)
- def test_wrap_view(self):
+ def test_api_view_incorrect(self):
+ """
+ If @api_view is not applied correct, we should raise an assertion.
+ """
- @api_view(['GET'])
+ @api_view
def view(request):
- return Response({})
+ return Response()
+
+ request = self.factory.get('/')
+ self.assertRaises(AssertionError, view, request)
- self.assertTrue(isinstance(view.cls_instance, APIView))
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
+
+ with self.assertRaises(AssertionError):
+ @api_view('GET')
+ def view(request):
+ return Response()
def test_calling_method(self):
@@ -63,6 +78,20 @@ class DecoratorTestCase(TestCase):
response = view(request)
self.assertEqual(response.status_code, 405)
+ def test_calling_patch_method(self):
+
+ @api_view(['GET', 'PATCH'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.patch('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 405)
+
def test_renderer_classes(self):
@api_view(['GET'])
diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/rest_framework/tests/extras/__init__.py
diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py
new file mode 100644
index 00000000..68263d94
--- /dev/null
+++ b/rest_framework/tests/extras/bad_import.py
@@ -0,0 +1 @@
+raise ValueError
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
new file mode 100644
index 00000000..8068272d
--- /dev/null
+++ b/rest_framework/tests/fields.py
@@ -0,0 +1,49 @@
+"""
+General serializer field tests.
+"""
+
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class TimestampedModel(models.Model):
+ added = models.DateTimeField(auto_now_add=True)
+ updated = models.DateTimeField(auto_now=True)
+
+
+class CharPrimaryKeyModel(models.Model):
+ id = models.CharField(max_length=20, primary_key=True)
+
+
+class TimestampedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimestampedModel
+
+
+class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = CharPrimaryKeyModel
+
+
+class ReadOnlyFieldTests(TestCase):
+ def test_auto_now_fields_read_only(self):
+ """
+ auto_now and auto_now_add fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEquals(serializer.fields['added'].read_only, True)
+
+ def test_auto_pk_fields_read_only(self):
+ """
+ AutoField fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEquals(serializer.fields['id'].read_only, True)
+
+ def test_non_auto_pk_fields_not_read_only(self):
+ """
+ PK fields other than AutoField fields should not be read_only by default.
+ """
+ serializer = CharPrimaryKeyModelSerializer()
+ self.assertEquals(serializer.fields['id'].read_only, False)
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
index ca6bc905..0434f900 100644
--- a/rest_framework/tests/files.py
+++ b/rest_framework/tests/files.py
@@ -26,7 +26,6 @@ class UploadedFileSerializer(serializers.Serializer):
class FileSerializerTests(TestCase):
-
def test_create(self):
now = datetime.datetime.now()
file = BytesIO(six.b('stuff'))
@@ -38,3 +37,16 @@ class FileSerializerTests(TestCase):
self.assertEquals(serializer.object.created, uploaded_file.created)
self.assertEquals(serializer.object.file, uploaded_file.file)
self.assertFalse(serializer.object is uploaded_file)
+
+ def test_creation_failure(self):
+ """
+ Passing files=None should result in an ValidationError
+
+ Regression test for:
+ https://github.com/tomchristie/django-rest-framework/issues/542
+ """
+ now = datetime.datetime.now()
+
+ serializer = UploadedFileSerializer(data={'created': now})
+ self.assertFalse(serializer.is_valid())
+ self.assertIn('file', serializer.errors)
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
index ba29dbed..72070a1a 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/genericrelations.py
@@ -1,27 +1,63 @@
from __future__ import unicode_literals
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import *
+
+
+class Tag(models.Model):
+ """
+ Tags have a descriptive slug, and are attached to an arbitrary object.
+ """
+ tag = models.SlugField()
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ tagged_item = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag
+
+
+class Bookmark(models.Model):
+ """
+ A URL bookmark that may have multiple tags attached.
+ """
+ url = models.URLField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Bookmark: %s' % self.url
+
+
+class Note(models.Model):
+ """
+ A textual note that may have multiple tags attached.
+ """
+ text = models.TextField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Note: %s' % self.text
class TestGenericRelations(TestCase):
def setUp(self):
- bookmark = Bookmark(url='https://www.djangoproject.com/')
- bookmark.save()
- django = Tag(tag_name='django')
- django.save()
- python = Tag(tag_name='python')
- python.save()
- t1 = TaggedItem(content_object=bookmark, tag=django)
- t1.save()
- t2 = TaggedItem(content_object=bookmark, tag=python)
- t2.save()
- self.bookmark = bookmark
-
- def test_reverse_generic_relation(self):
+ self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
+ Tag.objects.create(tagged_item=self.bookmark, tag='django')
+ Tag.objects.create(tagged_item=self.bookmark, tag='python')
+ self.note = Note.objects.create(text='Remember the milk')
+ Tag.objects.create(tagged_item=self.note, tag='reminder')
+
+ def test_generic_relation(self):
+ """
+ Test a relationship that spans a GenericRelation field.
+ IE. A reverse generic relationship.
+ """
+
class BookmarkSerializer(serializers.ModelSerializer):
- tags = serializers.ManyRelatedField(source='tags')
+ tags = serializers.ManyRelatedField()
class Meta:
model = Bookmark
@@ -33,3 +69,33 @@ class TestGenericRelations(TestCase):
'url': 'https://www.djangoproject.com/'
}
self.assertEquals(serializer.data, expected)
+
+ def test_generic_fk(self):
+ """
+ Test a relationship that spans a GenericForeignKey field.
+ IE. A forward generic relationship.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ tagged_item = serializers.RelatedField()
+
+ class Meta:
+ model = Tag
+ exclude = ('id', 'content_type', 'object_id')
+
+ serializer = TagSerializer(Tag.objects.all())
+ expected = [
+ {
+ 'tag': u'django',
+ 'tagged_item': u'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': u'python',
+ 'tagged_item': u'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': u'reminder',
+ 'tagged_item': u'Note: Remember the milk'
+ }
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index 215de0c4..fd01312a 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,12 +1,11 @@
from __future__ import unicode_literals
-
from django.db import models
from django.test import TestCase
-from django.test.client import RequestFactory
-from django.utils import simplejson as json
from rest_framework import generics, serializers, status
+from rest_framework.tests.utils import RequestFactory
from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
from rest_framework.compat import six
+import json
factory = RequestFactory()
@@ -183,6 +182,20 @@ class TestInstanceView(TestCase):
updated = self.objects.get(id=1)
self.assertEquals(updated.text, 'foobar')
+ def test_patch_instance_view(self):
+ """
+ PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ content = {'text': 'foobar'}
+ request = factory.patch('/1', json.dumps(content),
+ content_type='application/json')
+
+ response = self.view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEquals(updated.text, 'foobar')
+
def test_delete_instance_view(self):
"""
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index ee4d8e57..c6a8224b 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -1,6 +1,6 @@
+import json
from django.test import TestCase
from django.test.client import RequestFactory
-from django.utils import simplejson as json
from rest_framework import generics, status, serializers
from rest_framework.compat import patterns, url
from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 0759650a..9ab15328 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -71,6 +71,7 @@ class SlugBasedModel(RESTFrameworkModel):
class DefaultValueModel(RESTFrameworkModel):
text = models.CharField(default='foobar', max_length=100)
+ extra = models.CharField(blank=True, null=True, max_length=100)
class CallableDefaultValueModel(RESTFrameworkModel):
@@ -85,27 +86,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
rel = models.ManyToManyField(Anchor)
-# Models to test generic relations
-
-
-class Tag(RESTFrameworkModel):
- tag_name = models.SlugField()
-
-
-class TaggedItem(RESTFrameworkModel):
- tag = models.ForeignKey(Tag, related_name='items')
- content_type = models.ForeignKey(ContentType)
- object_id = models.PositiveIntegerField()
- content_object = GenericForeignKey('content_type', 'object_id')
-
- def __unicode__(self):
- return self.tag.tag_name
-
-
-class Bookmark(RESTFrameworkModel):
- url = models.URLField()
- tags = GenericRelation(TaggedItem)
-
# Model to test filtering.
class FilterableItem(RESTFrameworkModel):
@@ -176,3 +156,42 @@ class OptionalRelationModel(RESTFrameworkModel):
# Model for RegexField
class Book(RESTFrameworkModel):
isbn = models.CharField(max_length=13)
+
+
+# Models for relations tests
+# ManyToMany
+class ManyToManyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+
+
+# ForeignKey
+class ForeignKeyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
+ related_name='nullable_sources')
+
+
+# OneToOne
+class OneToOneTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class NullableOneToOneSource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
+ related_name='nullable_source')
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 81d297a1..697dfb5b 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -181,10 +181,10 @@ class UnitTestPagination(TestCase):
"""
Ensure context gets passed through to the object serializer.
"""
- serializer = PassOnContextPaginationSerializer(self.first_page)
+ serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
serializer.data
results = serializer.fields[serializer.results_field]
- self.assertTrue(serializer.context is results.context)
+ self.assertEquals(serializer.context, results.context)
class TestUnpaginated(TestCase):
@@ -252,6 +252,8 @@ class TestCustomPaginateByParam(TestCase):
self.assertEquals(response.data['results'], self.data[:5])
+### Tests for context in pagination serializers
+
class CustomField(serializers.Field):
def to_native(self, value):
if not 'view' in self.context:
@@ -262,6 +264,11 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer):
text = CustomField()
+ def __init__(self, *args, **kwargs):
+ super(BasicModelSerializer, self).__init__(*args, **kwargs)
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into serializer init")
+
class TestContextPassedToCustomField(TestCase):
def setUp(self):
@@ -279,3 +286,39 @@ class TestContextPassedToCustomField(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK)
+
+### Tests for custom pagination serializers
+
+class LinksSerializer(serializers.Serializer):
+ next = pagination.NextPageField(source='*')
+ prev = pagination.PreviousPageField(source='*')
+
+
+class CustomPaginationSerializer(pagination.BasePaginationSerializer):
+ links = LinksSerializer(source='*') # Takes the page object as the source
+ total_results = serializers.Field(source='paginator.count')
+
+ results_field = 'objects'
+
+
+class TestCustomPaginationSerializer(TestCase):
+ def setUp(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = Paginator(objects, 2)
+ self.page = paginator.page(1)
+
+ def test_custom_pagination_serializer(self):
+ request = RequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=self.page,
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page=2',
+ 'prev': None
+ },
+ 'total_results': 4,
+ 'objects': ['john', 'paul']
+ }
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py
new file mode 100644
index 00000000..edc85f9e
--- /dev/null
+++ b/rest_framework/tests/relations.py
@@ -0,0 +1,47 @@
+"""
+General tests for relational fields.
+"""
+
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class NullModel(models.Model):
+ pass
+
+
+class FieldTests(TestCase):
+ def test_pk_related_field_with_empty_string(self):
+ """
+ Regression test for #446
+
+ https://github.com/tomchristie/django-rest-framework/issues/446
+ """
+ field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_hyperlinked_related_field_with_empty_string(self):
+ field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_slug_related_field_with_empty_string(self):
+ field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+
+class TestManyRelateMixin(TestCase):
+ def test_missing_many_to_many_related_field(self):
+ '''
+ Regression test for #632
+
+ https://github.com/tomchristie/django-rest-framework/pull/632
+ '''
+ field = serializers.ManyRelatedField(read_only=False)
+
+ into = {}
+ field.field_from_native({}, None, 'field_name', into)
+ self.assertEqual(into['field_name'], [])
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index 407c04e0..b4ad3166 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -1,9 +1,9 @@
from __future__ import unicode_literals
-from django.db import models
from django.test import TestCase
from rest_framework import serializers
from rest_framework.compat import patterns, url
+from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
def dummy_view(request, pk):
@@ -15,20 +15,11 @@ urlpatterns = patterns('',
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
+ url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
+ url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
)
-# ManyToMany
-
-class ManyToManyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ManyToManySource(models.Model):
- name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
-
-
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail')
@@ -41,17 +32,6 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
model = ManyToManySource
-# ForeignKey
-
-class ForeignKeyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
-
-
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail')
@@ -65,16 +45,17 @@ class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
-class NullableForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
- related_name='nullable_sources')
+# OneToOne
+class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
+ nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
-class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
- model = NullableForeignKeySource
+ model = OneToOneTarget
# TODO: Add test that .data cannot be accessed prior to .is_valid
@@ -236,6 +217,13 @@ class HyperlinkedForeignKeyTests(TestCase):
]
self.assertEquals(serializer.data, expected)
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Incorrect type. Expected url string, received int.']})
+
def test_reverse_foreign_key_update(self):
data = {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
@@ -248,7 +236,7 @@ class HyperlinkedForeignKeyTests(TestCase):
expected = [
{'url': '/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
{'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- ]
+ ]
self.assertEquals(new_serializer.data, expected)
serializer.save()
@@ -434,3 +422,24 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
# {'id': 2, 'name': 'target-2', 'sources': []},
# ]
# self.assertEquals(serializer.data, expected)
+
+
+class HyperlinkedNullableOneToOneTests(TestCase):
+ urls = 'rest_framework.tests.relations_hyperlink'
+
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'},
+ {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index 442cbebe..e81f0e42 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -1,19 +1,7 @@
from __future__ import unicode_literals
-
-from django.db import models
from django.test import TestCase
from rest_framework import serializers
-
-
-# ForeignKey
-
-class ForeignKeyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
class ForeignKeySourceSerializer(serializers.ModelSerializer):
@@ -34,20 +22,24 @@ class ForeignKeyTargetSerializer(serializers.ModelSerializer):
model = ForeignKeyTarget
-# Nullable ForeignKey
-
-class NullableForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
- related_name='nullable_sources')
-
-
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
depth = 1
model = NullableForeignKeySource
+class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableOneToOneSource
+
+
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ nullable_source = NullableOneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+
+
class ReverseForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
@@ -102,3 +94,22 @@ class NestedNullableForeignKeyTests(TestCase):
{'id': 3, 'name': 'source-3', 'target': None},
]
self.assertEquals(serializer.data, expected)
+
+
+class NestedNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}},
+ {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index a04c5c80..4d00795a 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -3,17 +3,7 @@ from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
from rest_framework import serializers
-
-
-# ManyToMany
-
-class ManyToManyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ManyToManySource(models.Model):
- name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
class ManyToManyTargetSerializer(serializers.ModelSerializer):
@@ -28,17 +18,6 @@ class ManyToManySourceSerializer(serializers.ModelSerializer):
model = ManyToManySource
-# ForeignKey
-
-class ForeignKeyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
-
-
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = serializers.ManyPrimaryKeyRelatedField()
@@ -51,17 +30,17 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer):
model = ForeignKeySource
-# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
-class NullableForeignKeySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
- related_name='nullable_sources')
+# OneToOne
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ nullable_source = serializers.PrimaryKeyRelatedField()
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- model = NullableForeignKeySource
+ model = OneToOneTarget
# TODO: Add test that .data cannot be accessed prior to .is_valid
@@ -218,6 +197,13 @@ class PKForeignKeyTests(TestCase):
]
self.assertEquals(serializer.data, expected)
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 'foo'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Incorrect type. Expected pk value, received str.']})
+
def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2)
@@ -230,7 +216,7 @@ class PKForeignKeyTests(TestCase):
expected = [
{'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': 'target-2', 'sources': []},
- ]
+ ]
self.assertEquals(new_serializer.data, expected)
serializer.save()
@@ -414,3 +400,22 @@ class PKNullableForeignKeyTests(TestCase):
# {'id': 2, 'name': 'target-2', 'sources': []},
# ]
# self.assertEquals(serializer.data, expected)
+
+
+class PKNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'nullable_source': 1},
+ {'id': 2, 'name': u'target-2', 'nullable_source': None},
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py
new file mode 100644
index 00000000..37ccc75e
--- /dev/null
+++ b/rest_framework/tests/relations_slug.py
@@ -0,0 +1,257 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.ManySlugRelatedField(slug_field='name')
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name')
+
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name', null=True)
+
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# TODO: M2M Tests, FKTests (Non-nulable), One2One
+class PKForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': 'target-1'}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 'target-2'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-2'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': 'target-1'}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 123}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Object with name=123 does not exist.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEquals(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': u'source-4', 'target': 'target-2'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ serializer.is_valid()
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': 'target-1'},
+ {'id': 4, 'name': u'source-4', 'target': 'target-2'},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+
+
+class SlugNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 4, 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': u'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': u'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, expected_data)
+ self.assertEqual(obj.name, u'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': None},
+ {'id': 4, 'name': u'source-4', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': None},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': u'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': None},
+ {'id': 2, 'name': u'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': u'source-3', 'target': None}
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index 7d4575bb..92b1bfd8 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -1,12 +1,12 @@
"""
Tests for content parsing, and form-overloaded content parsing.
"""
+import json
from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase, Client
from django.test.client import RequestFactory
-from django.utils import simplejson as json
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
from rest_framework.compat import patterns
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 6ce7de31..a00626b5 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -56,6 +56,19 @@ class ActionItemSerializer(serializers.ModelSerializer):
model = ActionItem
+class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
+
+ class Meta:
+ model = ActionItem
+
+ def restore_object(self, data, instance=None):
+ if instance is None:
+ return ActionItem(**data)
+ for key, val in data.items():
+ setattr(instance, key, val)
+ return instance
+
+
class PersonSerializer(serializers.ModelSerializer):
info = serializers.Field(source='info')
@@ -71,6 +84,7 @@ class AlbumsSerializer(serializers.ModelSerializer):
model = Album
fields = ['title'] # lists are also valid options
+
class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
class Meta:
model = HasPositiveIntegerAsChoice
@@ -163,7 +177,6 @@ class BasicTests(TestCase):
"""
Attempting to update fields set as read_only should have no effect.
"""
-
serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
@@ -184,8 +197,7 @@ class ValidationTests(TestCase):
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
}
- self.actionitem = ActionItem(title='Some to do item',
- )
+ self.actionitem = ActionItem(title='Some to do item',)
def test_create(self):
serializer = CommentSerializer(data=self.data)
@@ -217,30 +229,24 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.errors, {})
- def test_field_validation(self):
-
- class CommentSerializerWithFieldValidator(CommentSerializer):
-
- def validate_content(self, attrs, source):
- value = attrs[source]
- if "test" not in value:
- raise serializers.ValidationError("Test not in value")
- return attrs
-
- data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
+ def test_bad_type_data_is_false(self):
+ """
+ Data of the wrong type is not valid.
+ """
+ data = ['i am', 'a', 'list']
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
- data['content'] = 'This should not validate'
+ data = 'and i am a string'
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'content': ['Test not in value']})
+ data = 42
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})
def test_cross_field_validation(self):
@@ -282,6 +288,20 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
+ def test_modelserializer_max_length_exceeded_with_custom_restore(self):
+ """
+ When overriding ModelSerializer.restore_object, validation tests should still apply.
+ Regression test for #623.
+
+ https://github.com/tomchristie/django-rest-framework/pull/623
+ """
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializerCustomRestore(data=data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']})
+
def test_default_modelfield_max_length_exceeded(self):
data = {
'title': 'Testing "info" field...',
@@ -292,12 +312,69 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
+class CustomValidationTests(TestCase):
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_email(self, attrs, source):
+ value = attrs[source]
+
+ return attrs
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ def test_field_validation(self):
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+
+ def test_missing_data(self):
+ """
+ Make sure that validate_content isn't called if the field is missing
+ """
+ incomplete_data = {
+ 'email': 'tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'This field is required.']})
+
+ def test_wrong_data(self):
+ """
+ Make sure that validate_content isn't called if the field input is wrong
+ """
+ wrong_data = {
+ 'email': 'not an email',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'email': [u'Enter a valid e-mail address.']})
+
+
class PositiveIntegerAsChoiceTests(TestCase):
def test_positive_integer_in_json_is_correctly_parsed(self):
- data = {'some_integer':1}
+ data = {'some_integer': 1}
serializer = PositiveIntegerAsChoiceSerializer(data=data)
self.assertEquals(serializer.is_valid(), True)
+
class ModelValidationTests(TestCase):
def test_validate_unique(self):
"""
@@ -342,7 +419,6 @@ class ModelValidationTests(TestCase):
self.assertTrue(photo_serializer.save())
-
class RegexValidationTest(TestCase):
def test_create_failed(self):
serializer = BookSerializer(data={'isbn': '1234567890'})
@@ -553,6 +629,21 @@ class DefaultValueTests(TestCase):
self.assertEquals(instance.pk, 1)
self.assertEquals(instance.text, 'overridden')
+ def test_partial_update_default(self):
+ """ Regression test for issue #532 """
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data, partial=True)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ data = {'extra': 'extra_value'}
+ serializer = self.serializer_class(instance=instance, data=data, partial=True)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ self.assertEquals(instance.extra, 'extra_value')
+ self.assertEquals(instance.text, 'overridden')
+
class CallableDefaultValueTests(TestCase):
def setUp(self):
diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py
new file mode 100644
index 00000000..0293fdc3
--- /dev/null
+++ b/rest_framework/tests/settings.py
@@ -0,0 +1,21 @@
+"""Tests for the settings module"""
+from django.test import TestCase
+
+from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
+
+
+class TestSettings(TestCase):
+ """Tests relating to the api settings"""
+
+ def test_non_import_errors(self):
+ """Make sure other errors aren't suppressed."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ValueError):
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+
+ def test_import_error_message_maintained(self):
+ """Make sure real import errors are captured and raised sensibly."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ImportError) as cm:
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+ self.assertTrue('ImportError' in str(cm.exception))
diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py
new file mode 100644
index 00000000..43e8ef69
--- /dev/null
+++ b/rest_framework/tests/urlpatterns.py
@@ -0,0 +1,78 @@
+from collections import namedtuple
+
+from django.core import urlresolvers
+
+from django.test import TestCase
+from django.test.client import RequestFactory
+
+from rest_framework.compat import patterns, url, include
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+# A container class for test paths for the test case
+URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+
+
+def dummy_view(request, *args, **kwargs):
+ pass
+
+
+class FormatSuffixTests(TestCase):
+ """
+ Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
+ """
+ def _resolve_urlpatterns(self, urlpatterns, test_paths):
+ factory = RequestFactory()
+ try:
+ urlpatterns = format_suffix_patterns(urlpatterns)
+ except:
+ self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
+ resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
+ for test_path in test_paths:
+ request = factory.get(test_path.path)
+ try:
+ callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ except:
+ self.fail("Failed to resolve URL: %s" % request.path_info)
+ self.assertEquals(callback_args, test_path.args)
+ self.assertEquals(callback_kwargs, test_path.kwargs)
+
+ def test_format_suffix(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {}),
+ URLTestPath('/test.api', (), {'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_default_args(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view, {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {'foo': 'bar', }),
+ URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_included_urls(self):
+ nested_patterns = patterns(
+ '',
+ url(r'^path$', dummy_view)
+ )
+ urlpatterns = patterns(
+ '',
+ url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test/path', (), {'foo': 'bar', }),
+ URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py
new file mode 100644
index 00000000..3906adb9
--- /dev/null
+++ b/rest_framework/tests/utils.py
@@ -0,0 +1,27 @@
+from django.test.client import RequestFactory, FakePayload
+from django.test.client import MULTIPART_CONTENT
+from urlparse import urlparse
+
+
+class RequestFactory(RequestFactory):
+
+ def __init__(self, **defaults):
+ super(RequestFactory, self).__init__(**defaults)
+
+ def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
+ **extra):
+ "Construct a PATCH request."
+
+ patch_data = self._encode_data(data, content_type)
+
+ parsed = urlparse(path)
+ r = {
+ 'CONTENT_LENGTH': len(patch_data),
+ 'CONTENT_TYPE': content_type,
+ 'PATH_INFO': self._get_path(parsed),
+ 'QUERY_STRING': parsed[4],
+ 'REQUEST_METHOD': 'PATCH',
+ 'wsgi.input': FakePayload(patch_data),
+ }
+ r.update(extra)
+ return self.request(**r)
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
index e51ca9f3..f2432516 100644
--- a/rest_framework/tests/views.py
+++ b/rest_framework/tests/views.py
@@ -20,7 +20,7 @@ class BasicView(APIView):
return Response({'method': 'POST', 'data': request.DATA})
-@api_view(['GET', 'POST', 'PUT'])
+@api_view(['GET', 'POST', 'PUT', 'PATCH'])
def basic_view(request):
if request.method == 'GET':
return {'method': 'GET'}
@@ -28,6 +28,8 @@ def basic_view(request):
return {'method': 'POST', 'data': request.DATA}
elif request.method == 'PUT':
return {'method': 'PUT', 'data': request.DATA}
+ elif request.method == 'PATCH':
+ return {'method': 'PATCH', 'data': request.DATA}
def sanitise_json_error(error_dict):
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 143928c9..47789026 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -1,5 +1,35 @@
-from rest_framework.compat import url
+from rest_framework.compat import url, include
from rest_framework.settings import api_settings
+from django.core.urlresolvers import RegexURLResolver
+
+
+def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
+ ret = []
+ for urlpattern in urlpatterns:
+ if isinstance(urlpattern, RegexURLResolver):
+ # Set of included URL patterns
+ regex = urlpattern.regex.pattern
+ namespace = urlpattern.namespace
+ app_name = urlpattern.app_name
+ kwargs = urlpattern.default_kwargs
+ # Add in the included patterns, after applying the suffixes
+ patterns = apply_suffix_patterns(urlpattern.url_patterns,
+ suffix_pattern,
+ suffix_required)
+ ret.append(url(regex, include(patterns, namespace, app_name), kwargs))
+
+ else:
+ # Regular URL pattern
+ regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern
+ view = urlpattern._callback or urlpattern._callback_str
+ kwargs = urlpattern.default_args
+ name = urlpattern.name
+ # Add in both the existing and the new urlpattern
+ if not suffix_required:
+ ret.append(urlpattern)
+ ret.append(url(regex, view, kwargs, name))
+
+ return ret
def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
@@ -28,15 +58,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
else:
suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg
- ret = []
- for urlpattern in urlpatterns:
- # Form our complementing '.format' urlpattern
- regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern
- view = urlpattern._callback or urlpattern._callback_str
- kwargs = urlpattern.default_args
- name = urlpattern.name
- # Add in both the existing and the new urlpattern
- if not suffix_required:
- ret.append(urlpattern)
- ret.append(url(regex, view, kwargs, name))
- return ret
+ return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 2d1fb353..7afe100a 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -4,7 +4,7 @@ Helper classes for parsers.
import datetime
import decimal
import types
-from django.utils import simplejson as json
+import json
from django.utils.datastructures import SortedDict
from rest_framework.compat import timezone
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
@@ -12,7 +12,7 @@ from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
class JSONEncoder(json.JSONEncoder):
"""
- JSONEncoder subclass that knows how to encode date/time,
+ JSONEncoder subclass that knows how to encode date/time/timedelta,
decimal types, and generators.
"""
def default(self, o):
@@ -34,6 +34,8 @@ class JSONEncoder(json.JSONEncoder):
if o.microsecond:
r = r[:12]
return r
+ elif isinstance(o, datetime.timedelta):
+ return str(o.total_seconds())
elif isinstance(o, decimal.Decimal):
return str(o)
elif hasattr(o, '__iter__'):
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 10bdd5a5..ac9b3385 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -148,6 +148,8 @@ class APIView(View):
"""
If request is not permitted, determine what kind of exception to raise.
"""
+ if not self.request.successful_authenticator:
+ raise exceptions.NotAuthenticated()
raise exceptions.PermissionDenied()
def throttled(self, request, wait):
@@ -156,6 +158,15 @@ class APIView(View):
"""
raise exceptions.Throttled(wait)
+ def get_authenticate_header(self, request):
+ """
+ If a request is unauthenticated, determine the WWW-Authenticate
+ header to use for 401 responses, if any.
+ """
+ authenticators = self.get_authenticators()
+ if authenticators:
+ return authenticators[0].authenticate_header(request)
+
def get_parser_context(self, http_request):
"""
Returns a dict that is passed through to Parser.parse(),
@@ -319,6 +330,16 @@ class APIView(View):
# Throttle wait header
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
+ if isinstance(exc, (exceptions.NotAuthenticated,
+ exceptions.AuthenticationFailed)):
+ # WWW-Authenticate header for 401 responses, else coerce to 403
+ auth_header = self.get_authenticate_header(self.request)
+
+ if auth_header:
+ self.headers['WWW-Authenticate'] = auth_header
+ else:
+ exc.status_code = status.HTTP_403_FORBIDDEN
+
if isinstance(exc, exceptions.APIException):
return Response({'detail': exc.detail},
status=exc.status_code,