From 3fcc01273c5efef26d911e50c02a4a43f89b34eb Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 27 Jun 2013 20:29:52 +0100
Subject: Remove deprecated code
---
rest_framework/compat.py | 5 ++-
rest_framework/fields.py | 7 ----
rest_framework/permissions.py | 12 +-----
rest_framework/relations.py | 69 +++++----------------------------
rest_framework/serializers.py | 32 ++-------------
rest_framework/tests/test_serializer.py | 4 +-
6 files changed, 21 insertions(+), 108 deletions(-)
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index b748dcc5..161fffa8 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -494,11 +494,14 @@ try:
if provider_version in ('0.2.3', '0.2.4'):
# 0.2.3 and 0.2.4 are supported version that do not support
# timezone aware datetimes
- from datetime.datetime import now as provider_now
+ import datetime
+ provider_now = datetime.datetime.now
else:
# Any other supported version does use timezone aware datetimes
from django.utils.timezone import now as provider_now
except ImportError:
+ import traceback
+ traceback.print_exc()
oauth2_provider = None
oauth2_provider_models = None
oauth2_provider_forms = None
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 35848b4c..2e23715d 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -224,13 +224,6 @@ class WritableField(Field):
validators=[], error_messages=None, widget=None,
default=None, blank=None):
- # 'blank' is to be deprecated in favor of 'required'
- if blank is not None:
- warnings.warn('The `blank` keyword argument is deprecated. '
- 'Use the `required` keyword argument instead.',
- DeprecationWarning, stacklevel=2)
- required = not(blank)
-
super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
self.read_only = read_only
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 1036663e..0c7b02ff 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -2,13 +2,10 @@
Provides a set of pluggable permission policies.
"""
from __future__ import unicode_literals
-import inspect
-import warnings
+from rest_framework.compat import oauth2_provider_scope, oauth2_constants
SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS']
-from rest_framework.compat import oauth2_provider_scope, oauth2_constants
-
class BasePermission(object):
"""
@@ -25,13 +22,6 @@ class BasePermission(object):
"""
Return `True` if permission is granted, `False` otherwise.
"""
- if len(inspect.getargspec(self.has_permission).args) == 4:
- warnings.warn(
- 'The `obj` argument in `has_permission` is deprecated. '
- 'Use `has_object_permission()` instead for object permissions.',
- DeprecationWarning, stacklevel=2
- )
- return self.has_permission(request, view, obj)
return True
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index edaf76d6..ede694e3 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -40,14 +40,6 @@ class RelatedField(WritableField):
many = False
def __init__(self, *args, **kwargs):
-
- # 'null' is to be deprecated in favor of 'required'
- if 'null' in kwargs:
- warnings.warn('The `null` keyword argument is deprecated. '
- 'Use the `required` keyword argument instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['required'] = not kwargs.pop('null')
-
queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', self.many)
if self.many:
@@ -424,14 +416,11 @@ class HyperlinkedRelatedField(RelatedField):
request = self.context.get('request', None)
format = self.format or self.context.get('format', None)
- if request is None:
- msg = (
- "Using `HyperlinkedRelatedField` without including the request "
- "in the serializer context is deprecated. "
- "Add `context={'request': request}` when instantiating "
- "the serializer."
- )
- warnings.warn(msg, DeprecationWarning, stacklevel=4)
+ assert request is not None, (
+ "`HyperlinkedRelatedField` requires the request in the serializer "
+ "context. Add `context={'request': request}` when instantiating "
+ "the serializer."
+ )
# If the object has not yet been saved then we cannot hyperlink to it.
if getattr(obj, 'pk', None) is None:
@@ -530,11 +519,11 @@ class HyperlinkedIdentityField(Field):
format = self.context.get('format', None)
view_name = self.view_name
- if request is None:
- warnings.warn("Using `HyperlinkedIdentityField` without including the "
- "request in the serializer context is deprecated. "
- "Add `context={'request': request}` when instantiating the serializer.",
- DeprecationWarning, stacklevel=4)
+ assert request is not None, (
+ "`HyperlinkedIdentityField` requires the request in the serializer"
+ " context. Add `context={'request': request}` when instantiating "
+ "the serializer."
+ )
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
@@ -593,41 +582,3 @@ class HyperlinkedIdentityField(Field):
pass
raise NoReverseMatch()
-
-
-### Old-style many classes for backwards compat
-
-class ManyRelatedField(RelatedField):
- def __init__(self, *args, **kwargs):
- warnings.warn('`ManyRelatedField()` is deprecated. '
- 'Use `RelatedField(many=True)` instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['many'] = True
- super(ManyRelatedField, self).__init__(*args, **kwargs)
-
-
-class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
- def __init__(self, *args, **kwargs):
- warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. '
- 'Use `PrimaryKeyRelatedField(many=True)` instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['many'] = True
- super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)
-
-
-class ManySlugRelatedField(SlugRelatedField):
- def __init__(self, *args, **kwargs):
- warnings.warn('`ManySlugRelatedField()` is deprecated. '
- 'Use `SlugRelatedField(many=True)` instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['many'] = True
- super(ManySlugRelatedField, self).__init__(*args, **kwargs)
-
-
-class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):
- def __init__(self, *args, **kwargs):
- warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. '
- 'Use `HyperlinkedRelatedField(many=True)` instead.',
- DeprecationWarning, stacklevel=2)
- kwargs['many'] = True
- super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 023f7ccf..ae39cce8 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -15,7 +15,6 @@ import copy
import datetime
import types
from decimal import Decimal
-from django.core.paginator import Page
from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
@@ -141,7 +140,7 @@ class BaseSerializer(WritableField):
_dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=None,
+ context=None, partial=False, many=False,
allow_add_remove=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
@@ -348,12 +347,7 @@ class BaseSerializer(WritableField):
if value is None:
return None
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type))
-
- if many:
+ if self.many:
return [self.to_native(item) for item in value]
return self.to_native(value)
@@ -424,16 +418,7 @@ class BaseSerializer(WritableField):
if self._errors is None:
data, files = self.init_data, self.init_files
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
- if many:
- warnings.warn('Implict list/queryset serialization is deprecated. '
- 'Use the `many=True` flag when instantiating the serializer.',
- DeprecationWarning, stacklevel=3)
-
- if many:
+ if self.many:
ret = []
errors = []
update = self.object is not None
@@ -486,16 +471,7 @@ class BaseSerializer(WritableField):
if self._data is None:
obj = self.object
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
- if many:
- warnings.warn('Implict list/queryset serialization is deprecated. '
- 'Use the `many=True` flag when instantiating the serializer.',
- DeprecationWarning, stacklevel=2)
-
- if many:
+ if self.many:
self._data = [self.to_native(item) for item in obj]
else:
self._data = self.to_native(obj)
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 8b87a084..151eb648 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -1268,7 +1268,7 @@ class NestedSerializerContextTests(TestCase):
model = Album
fields = ("photo_set", "callable")
- photo_set = PhotoSerializer(source="photo_set")
+ photo_set = PhotoSerializer(source="photo_set", many=True)
callable = serializers.SerializerMethodField("_callable")
def _callable(self, instance):
@@ -1280,7 +1280,7 @@ class NestedSerializerContextTests(TestCase):
albums = None
class AlbumCollectionSerializer(serializers.Serializer):
- albums = AlbumSerializer(source="albums")
+ albums = AlbumSerializer(source="albums", many=True)
album1 = Album.objects.create(title="album 1")
album2 = Album.objects.create(title="album 2")
--
cgit v1.2.3
From 379ad8a82485e61b180ee823ba49799d39446aeb Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 27 Jun 2013 20:36:14 +0100
Subject: pending deprecations -> deprecated
---
rest_framework/generics.py | 32 ++++++++++++++++----------------
rest_framework/mixins.py | 8 ++++----
rest_framework/relations.py | 36 ++++++++++++++++++------------------
rest_framework/serializers.py | 8 ++++----
4 files changed, 42 insertions(+), 42 deletions(-)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 99e9782e..874a142c 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -108,11 +108,11 @@ class GenericAPIView(views.APIView):
deprecated_style = False
if page_size is not None:
warnings.warn('The `page_size` parameter to `paginate_queryset()` '
- 'is due to be deprecated. '
+ 'is deprecated. '
'Note that the return style of this method is also '
'changed, and will simply return a page object '
'when called without a `page_size` argument.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
deprecated_style = True
else:
# Determine the required page size.
@@ -123,10 +123,10 @@ class GenericAPIView(views.APIView):
if not self.allow_empty:
warnings.warn(
- 'The `allow_empty` parameter is due to be deprecated. '
+ 'The `allow_empty` parameter is deprecated. '
'To use `allow_empty=False` style behavior, You should override '
'`get_queryset()` and explicitly raise a 404 on empty querysets.',
- PendingDeprecationWarning, stacklevel=2
+ DeprecationWarning, stacklevel=2
)
paginator = self.paginator_class(queryset, page_size,
@@ -166,10 +166,10 @@ class GenericAPIView(views.APIView):
if not filter_backends and self.filter_backend:
warnings.warn(
'The `filter_backend` attribute and `FILTER_BACKEND` setting '
- 'are due to be deprecated in favor of a `filter_backends` '
+ 'are deprecated in favor of a `filter_backends` '
'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
'a *list* of filter backend classes.',
- PendingDeprecationWarning, stacklevel=2
+ DeprecationWarning, stacklevel=2
)
filter_backends = [self.filter_backend]
@@ -192,8 +192,8 @@ class GenericAPIView(views.APIView):
"""
if queryset is not None:
warnings.warn('The `queryset` parameter to `get_paginate_by()` '
- 'is due to be deprecated.',
- PendingDeprecationWarning, stacklevel=2)
+ 'is deprecated.',
+ DeprecationWarning, stacklevel=2)
if self.paginate_by_param:
query_params = self.request.QUERY_PARAMS
@@ -272,16 +272,16 @@ class GenericAPIView(views.APIView):
filter_kwargs = {self.lookup_field: lookup}
elif pk is not None and self.lookup_field == 'pk':
warnings.warn(
- 'The `pk_url_kwarg` attribute is due to be deprecated. '
+ 'The `pk_url_kwarg` attribute is deprecated. '
'Use the `lookup_field` attribute instead',
- PendingDeprecationWarning
+ DeprecationWarning
)
filter_kwargs = {'pk': pk}
elif slug is not None and self.lookup_field == 'pk':
warnings.warn(
- 'The `slug_url_kwarg` attribute is due to be deprecated. '
+ 'The `slug_url_kwarg` attribute is deprecated. '
'Use the `lookup_field` attribute instead',
- PendingDeprecationWarning
+ DeprecationWarning
)
filter_kwargs = {self.slug_field: slug}
else:
@@ -482,9 +482,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
class MultipleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs):
warnings.warn(
- 'Subclassing `MultipleObjectAPIView` is due to be deprecated. '
+ 'Subclassing `MultipleObjectAPIView` is deprecated. '
'You should simply subclass `GenericAPIView` instead.',
- PendingDeprecationWarning, stacklevel=2
+ DeprecationWarning, stacklevel=2
)
super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
@@ -492,8 +492,8 @@ class MultipleObjectAPIView(GenericAPIView):
class SingleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs):
warnings.warn(
- 'Subclassing `SingleObjectAPIView` is due to be deprecated. '
+ 'Subclassing `SingleObjectAPIView` is deprecated. '
'You should simply subclass `GenericAPIView` instead.',
- PendingDeprecationWarning, stacklevel=2
+ DeprecationWarning, stacklevel=2
)
super(SingleObjectAPIView, self).__init__(*args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index f11def6d..679dfa6c 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -24,14 +24,14 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None)
include = []
if pk:
- # Pending deprecation
+ # Deprecated
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
include.append(pk_field.name)
if slug_field:
- # Pending deprecation
+ # Deprecated
include.append(slug_field)
if lookup_field and lookup_field != 'pk':
@@ -77,10 +77,10 @@ class ListModelMixin(object):
# `.allow_empty = False`, to raise 404 errors on empty querysets.
if not self.allow_empty and not self.object_list:
warnings.warn(
- 'The `allow_empty` parameter is due to be deprecated. '
+ 'The `allow_empty` parameter is deprecated. '
'To use `allow_empty=False` style behavior, You should override '
'`get_queryset()` and explicitly raise a 404 on empty querysets.',
- PendingDeprecationWarning
+ DeprecationWarning
)
class_name = self.__class__.__name__
error_msg = self.empty_error % {'class_name': class_name}
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index ede694e3..f1f7dea7 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -314,7 +314,7 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
- # These are all pending deprecation
+ # These are all deprecated
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
@@ -328,16 +328,16 @@ class HyperlinkedRelatedField(RelatedField):
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.format = kwargs.pop('format', None)
- # These are pending deprecation
+ # These are deprecated
if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_field' in kwargs:
- msg = 'slug_field is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'slug_field is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
@@ -380,9 +380,9 @@ class HyperlinkedRelatedField(RelatedField):
# If the lookup succeeds using the default slug params,
# then `slug_field` is being used implicitly, and we
# we need to warn about the pending deprecation.
- msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \
+ msg = 'Implicit slug field hyperlinked fields are deprecated.' \
'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
return ret
except NoReverseMatch:
pass
@@ -480,7 +480,7 @@ class HyperlinkedIdentityField(Field):
lookup_field = 'pk'
read_only = True
- # These are all pending deprecation
+ # These are all deprecated
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
@@ -496,16 +496,16 @@ class HyperlinkedIdentityField(Field):
lookup_field = kwargs.pop('lookup_field', None)
self.lookup_field = lookup_field or self.lookup_field
- # These are pending deprecation
+ # These are deprecated
if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
if 'slug_field' in kwargs:
- msg = 'slug_field is pending deprecation. Use lookup_field instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ msg = 'slug_field is deprecated. Use lookup_field instead.'
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index ae39cce8..dd9e14ad 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -593,10 +593,10 @@ class ModelSerializer(Serializer):
if len(inspect.getargspec(self.get_nested_field).args) == 2:
warnings.warn(
'The `get_nested_field(model_field)` call signature '
- 'is due to be deprecated. '
+ 'is deprecated. '
'Use `get_nested_field(model_field, related_model, '
'to_many) instead',
- PendingDeprecationWarning
+ DeprecationWarning
)
field = self.get_nested_field(model_field)
else:
@@ -605,10 +605,10 @@ class ModelSerializer(Serializer):
if len(inspect.getargspec(self.get_nested_field).args) == 3:
warnings.warn(
'The `get_related_field(model_field, to_many)` call '
- 'signature is due to be deprecated. '
+ 'signature is deprecated. '
'Use `get_related_field(model_field, related_model, '
'to_many) instead',
- PendingDeprecationWarning
+ DeprecationWarning
)
field = self.get_related_field(model_field, to_many=to_many)
else:
--
cgit v1.2.3
From d72603bc6a16112008959c5267839f819c2bc43a Mon Sep 17 00:00:00 2001
From: Alex Burgel
Date: Wed, 5 Jun 2013 17:39:14 -0400
Subject: Add support for collection routes to SimpleRouter
---
rest_framework/decorators.py | 26 +++++++++++++++++++
rest_framework/routers.py | 33 ++++++++++++++++++++++---
rest_framework/tests/test_routers.py | 48 +++++++++++++++++++++++++++++++++++-
3 files changed, 103 insertions(+), 4 deletions(-)
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index c69756a4..dacd380f 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -113,6 +113,7 @@ def link(**kwargs):
"""
def decorator(func):
func.bind_to_methods = ['get']
+ func.collection = False
func.kwargs = kwargs
return func
return decorator
@@ -124,6 +125,31 @@ def action(methods=['post'], **kwargs):
"""
def decorator(func):
func.bind_to_methods = methods
+ func.collection = False
+ func.kwargs = kwargs
+ return func
+ return decorator
+
+
+def collection_link(**kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for GET requests.
+ """
+ def decorator(func):
+ func.bind_to_methods = ['get']
+ func.collection = True
+ func.kwargs = kwargs
+ return func
+ return decorator
+
+
+def collection_action(methods=['post'], **kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for POST requests.
+ """
+ def decorator(func):
+ func.bind_to_methods = methods
+ func.collection = True
func.kwargs = kwargs
return func
return decorator
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 930011d3..9b859a7c 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -88,6 +88,17 @@ class SimpleRouter(BaseRouter):
name='{basename}-list',
initkwargs={'suffix': 'List'}
),
+ # Dynamically generated collection routes.
+ # Generated using @collection_action or @collection_link decorators
+ # on methods of the viewset.
+ Route(
+ url=r'^{prefix}/{methodname}{trailing_slash}$',
+ mapping={
+ '{httpmethod}': '{methodname}',
+ },
+ name='{basename}-collection-{methodnamehyphen}',
+ initkwargs={}
+ ),
# Detail route.
Route(
url=r'^{prefix}/{lookup}{trailing_slash}$',
@@ -107,7 +118,7 @@ class SimpleRouter(BaseRouter):
mapping={
'{httpmethod}': '{methodname}',
},
- name='{basename}-{methodnamehyphen}',
+ name='{basename}-dynamic-{methodnamehyphen}',
initkwargs={}
),
]
@@ -142,20 +153,25 @@ class SimpleRouter(BaseRouter):
known_actions = flatten([route.mapping.values() for route in self.routes])
# Determine any `@action` or `@link` decorated methods on the viewset
+ collection_routes = []
dynamic_routes = []
for methodname in dir(viewset):
attr = getattr(viewset, methodname)
httpmethods = getattr(attr, 'bind_to_methods', None)
+ collection = getattr(attr, 'collection', False)
if httpmethods:
if methodname in known_actions:
raise ImproperlyConfigured('Cannot use @action or @link decorator on '
'method "%s" as it is an existing route' % methodname)
httpmethods = [method.lower() for method in httpmethods]
- dynamic_routes.append((httpmethods, methodname))
+ if collection:
+ collection_routes.append((httpmethods, methodname))
+ else:
+ dynamic_routes.append((httpmethods, methodname))
ret = []
for route in self.routes:
- if route.mapping == {'{httpmethod}': '{methodname}'}:
+ if route.name == '{basename}-dynamic-{methodnamehyphen}':
# Dynamic routes (@link or @action decorator)
for httpmethods, methodname in dynamic_routes:
initkwargs = route.initkwargs.copy()
@@ -166,6 +182,17 @@ class SimpleRouter(BaseRouter):
name=replace_methodname(route.name, methodname),
initkwargs=initkwargs,
))
+ elif route.name == '{basename}-collection-{methodnamehyphen}':
+ # Dynamic routes (@collection_link or @collection_action decorator)
+ for httpmethods, methodname in collection_routes:
+ initkwargs = route.initkwargs.copy()
+ initkwargs.update(getattr(viewset, methodname).kwargs)
+ ret.append(Route(
+ url=replace_methodname(route.url, methodname),
+ mapping=dict((httpmethod, methodname) for httpmethod in httpmethods),
+ name=replace_methodname(route.name, methodname),
+ initkwargs=initkwargs,
+ ))
else:
# Standard route
ret.append(route)
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 5fcccb74..60f150d2 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -4,7 +4,7 @@ from django.test import TestCase
from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers, viewsets, permissions
from rest_framework.compat import include, patterns, url
-from rest_framework.decorators import link, action
+from rest_framework.decorators import link, action, collection_link, collection_action
from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter
from rest_framework.test import APIRequestFactory
@@ -214,3 +214,49 @@ class TestActionAppliedToExistingRoute(TestCase):
with self.assertRaises(ImproperlyConfigured):
self.router.urls
+
+
+class StaticAndDynamicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @collection_action()
+ def collection_action(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def dynamic_action(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @collection_link()
+ def collection_link(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def dynamic_link(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestStaticAndDynamicRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(StaticAndDynamicViewSet)
+ decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['collection_action', 'collection_link', 'dynamic_action', 'dynamic_link']):
+ route = decorator_routes[i]
+ # check url listing
+ if endpoint.startswith('collection_'):
+ self.assertEqual(route.url,
+ '^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint))
+ else:
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
+ # check method to function mapping
+ if endpoint.endswith('action'):
+ method_map = 'post'
+ else:
+ method_map = 'get'
+ self.assertEqual(route.mapping[method_map], endpoint)
--
cgit v1.2.3
From 5b11e23f6fb35834057fba35832a597ce443cc77 Mon Sep 17 00:00:00 2001
From: Alex Burgel
Date: Wed, 5 Jun 2013 17:41:29 -0400
Subject: Add docs for collection routes
---
docs/api-guide/viewsets.md | 15 ++++++++++++---
1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md
index 47e59e2b..9fa6615b 100644
--- a/docs/api-guide/viewsets.md
+++ b/docs/api-guide/viewsets.md
@@ -92,7 +92,9 @@ The default routers included with REST framework will provide routes for a stand
def destroy(self, request, pk=None):
pass
-If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@link` or `@action` decorators. The `@link` decorator will route `GET` requests, and the `@action` decorator will route `POST` requests.
+If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@collection_link`, `@collection_action`, `@link`, or `@action` decorators. The `@collection_link` and `@link` decorator will route `GET` requests, and the `@collection_action` and `@action` decorator will route `POST` requests.
+
+The `@link` and `@action` decorators contain `pk` in their URL pattern and are intended for methods which require a single instance. The `@collection_link` and `@collection_action` decorators are intended for methods which operate on a collection of objects.
For example:
@@ -121,13 +123,20 @@ For example:
return Response(serializer.errors,
status=status.HTTP_400_BAD_REQUEST)
-The `@action` and `@link` decorators can additionally take extra arguments that will be set for the routed view only. For example...
+ @collection_link()
+ def recent_users(self, request):
+ recent_users = User.objects.all().order('-last_login')
+ page = self.paginate_queryset(recent_users)
+ serializer = self.get_pagination_serializer(page)
+ return Response(serializer.data)
+
+The decorators can additionally take extra arguments that will be set for the routed view only. For example...
@action(permission_classes=[IsAdminOrIsSelf])
def set_password(self, request, pk=None):
...
-The `@action` decorator will route `POST` requests by default, but may also accept other HTTP methods, by using the `method` argument. For example:
+The `@collection_action` and `@action` decorators will route `POST` requests by default, but may also accept other HTTP methods, by using the `method` argument. For example:
@action(methods=['POST', 'DELETE'])
def unset_password(self, request, pk=None):
--
cgit v1.2.3
From 57cf8b5fa4f62f9b58912f10536a7ae5076ce54c Mon Sep 17 00:00:00 2001
From: Alex Burgel
Date: Thu, 6 Jun 2013 11:51:52 -0400
Subject: Rework extra routes doc for better readability
---
docs/api-guide/viewsets.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md
index 9fa6615b..e83487fb 100644
--- a/docs/api-guide/viewsets.md
+++ b/docs/api-guide/viewsets.md
@@ -92,7 +92,7 @@ The default routers included with REST framework will provide routes for a stand
def destroy(self, request, pk=None):
pass
-If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@collection_link`, `@collection_action`, `@link`, or `@action` decorators. The `@collection_link` and `@link` decorator will route `GET` requests, and the `@collection_action` and `@action` decorator will route `POST` requests.
+If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@link`, `@action`, `@collection_link`, or `@collection_action` decorators. The `@link` and `@collection_link` decorators will route `GET` requests, and the `@action` and `@collection_action` decorators will route `POST` requests.
The `@link` and `@action` decorators contain `pk` in their URL pattern and are intended for methods which require a single instance. The `@collection_link` and `@collection_action` decorators are intended for methods which operate on a collection of objects.
@@ -136,7 +136,7 @@ The decorators can additionally take extra arguments that will be set for the ro
def set_password(self, request, pk=None):
...
-The `@collection_action` and `@action` decorators will route `POST` requests by default, but may also accept other HTTP methods, by using the `method` argument. For example:
+The `@action` and `@collection_action` decorators will route `POST` requests by default, but may also accept other HTTP methods, by using the `methods` argument. For example:
@action(methods=['POST', 'DELETE'])
def unset_password(self, request, pk=None):
--
cgit v1.2.3
From 8d521c068a254cef604df1f15690275dca986778 Mon Sep 17 00:00:00 2001
From: Alex Burgel
Date: Sun, 16 Jun 2013 12:43:59 -0400
Subject: Revert route name change and add key to Route object to identify
different route types
---
rest_framework/routers.py | 16 +++++++++++-----
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 9b859a7c..541df4a9 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -25,7 +25,7 @@ from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns
-Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs'])
+Route = namedtuple('Route', ['key', 'url', 'mapping', 'name', 'initkwargs'])
def replace_methodname(format_string, methodname):
@@ -80,6 +80,7 @@ class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
+ key='list',
url=r'^{prefix}{trailing_slash}$',
mapping={
'get': 'list',
@@ -92,15 +93,17 @@ class SimpleRouter(BaseRouter):
# Generated using @collection_action or @collection_link decorators
# on methods of the viewset.
Route(
+ key='collection',
url=r'^{prefix}/{methodname}{trailing_slash}$',
mapping={
'{httpmethod}': '{methodname}',
},
- name='{basename}-collection-{methodnamehyphen}',
+ name='{basename}-{methodnamehyphen}',
initkwargs={}
),
# Detail route.
Route(
+ key='detail',
url=r'^{prefix}/{lookup}{trailing_slash}$',
mapping={
'get': 'retrieve',
@@ -114,11 +117,12 @@ class SimpleRouter(BaseRouter):
# Dynamically generated routes.
# Generated using @action or @link decorators on methods of the viewset.
Route(
+ key='dynamic',
url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',
mapping={
'{httpmethod}': '{methodname}',
},
- name='{basename}-dynamic-{methodnamehyphen}',
+ name='{basename}-{methodnamehyphen}',
initkwargs={}
),
]
@@ -171,23 +175,25 @@ class SimpleRouter(BaseRouter):
ret = []
for route in self.routes:
- if route.name == '{basename}-dynamic-{methodnamehyphen}':
+ if route.key == 'dynamic':
# Dynamic routes (@link or @action decorator)
for httpmethods, methodname in dynamic_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
+ key=route.key,
url=replace_methodname(route.url, methodname),
mapping=dict((httpmethod, methodname) for httpmethod in httpmethods),
name=replace_methodname(route.name, methodname),
initkwargs=initkwargs,
))
- elif route.name == '{basename}-collection-{methodnamehyphen}':
+ elif route.key == 'collection':
# Dynamic routes (@collection_link or @collection_action decorator)
for httpmethods, methodname in collection_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
+ key=route.key,
url=replace_methodname(route.url, methodname),
mapping=dict((httpmethod, methodname) for httpmethod in httpmethods),
name=replace_methodname(route.name, methodname),
--
cgit v1.2.3
From f02274307826ebf98998e502fecca171bb0de696 Mon Sep 17 00:00:00 2001
From: Alex Burgel
Date: Sun, 16 Jun 2013 12:51:33 -0400
Subject: Rename router collection test case
---
rest_framework/tests/test_routers.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 60f150d2..e0a7e292 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -216,7 +216,7 @@ class TestActionAppliedToExistingRoute(TestCase):
self.router.urls
-class StaticAndDynamicViewSet(viewsets.ViewSet):
+class CollectionAndDynamicViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
return Response({'method': 'list'})
@@ -237,12 +237,12 @@ class StaticAndDynamicViewSet(viewsets.ViewSet):
return Response({'method': 'link2'})
-class TestStaticAndDynamicRouter(TestCase):
+class TestCollectionAndDynamicRouter(TestCase):
def setUp(self):
self.router = SimpleRouter()
def test_link_and_action_decorator(self):
- routes = self.router.get_routes(StaticAndDynamicViewSet)
+ routes = self.router.get_routes(CollectionAndDynamicViewSet)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
# Make sure all these endpoints exist and none have been clobbered
for i, endpoint in enumerate(['collection_action', 'collection_link', 'dynamic_action', 'dynamic_link']):
--
cgit v1.2.3
From e14cbaf6961ad9c94deaf0417d8e8ce5ec96d0ac Mon Sep 17 00:00:00 2001
From: Alex Burgel
Date: Sat, 13 Jul 2013 11:11:53 -0400
Subject: Changed collection_* decorators to list_*
---
docs/api-guide/viewsets.md | 10 +++++-----
rest_framework/decorators.py | 16 ++++++++--------
rest_framework/routers.py | 31 ++++++++++++++++---------------
rest_framework/tests/test_routers.py | 24 ++++++++++++------------
4 files changed, 41 insertions(+), 40 deletions(-)
diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md
index e83487fb..6d6bb133 100644
--- a/docs/api-guide/viewsets.md
+++ b/docs/api-guide/viewsets.md
@@ -92,15 +92,15 @@ The default routers included with REST framework will provide routes for a stand
def destroy(self, request, pk=None):
pass
-If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@link`, `@action`, `@collection_link`, or `@collection_action` decorators. The `@link` and `@collection_link` decorators will route `GET` requests, and the `@action` and `@collection_action` decorators will route `POST` requests.
+If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@link`, `@action`, `@list_link`, or `@list_action` decorators. The `@link` and `@list_link` decorators will route `GET` requests, and the `@action` and `@list_action` decorators will route `POST` requests.
-The `@link` and `@action` decorators contain `pk` in their URL pattern and are intended for methods which require a single instance. The `@collection_link` and `@collection_action` decorators are intended for methods which operate on a collection of objects.
+The `@link` and `@action` decorators contain `pk` in their URL pattern and are intended for methods which require a single instance. The `@list_link` and `@list_action` decorators are intended for methods which operate on a list of objects.
For example:
from django.contrib.auth.models import User
from rest_framework import viewsets
- from rest_framework.decorators import action
+ from rest_framework.decorators import action, list_link
from rest_framework.response import Response
from myapp.serializers import UserSerializer, PasswordSerializer
@@ -123,7 +123,7 @@ For example:
return Response(serializer.errors,
status=status.HTTP_400_BAD_REQUEST)
- @collection_link()
+ @list_link()
def recent_users(self, request):
recent_users = User.objects.all().order('-last_login')
page = self.paginate_queryset(recent_users)
@@ -136,7 +136,7 @@ The decorators can additionally take extra arguments that will be set for the ro
def set_password(self, request, pk=None):
...
-The `@action` and `@collection_action` decorators will route `POST` requests by default, but may also accept other HTTP methods, by using the `methods` argument. For example:
+The `@action` and `@list_action` decorators will route `POST` requests by default, but may also accept other HTTP methods, by using the `methods` argument. For example:
@action(methods=['POST', 'DELETE'])
def unset_password(self, request, pk=None):
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index dacd380f..92f551db 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -109,11 +109,11 @@ def permission_classes(permission_classes):
def link(**kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for GET requests.
+ Used to mark a method on a ViewSet that should be routed for detail GET requests.
"""
def decorator(func):
func.bind_to_methods = ['get']
- func.collection = False
+ func.detail = True
func.kwargs = kwargs
return func
return decorator
@@ -121,35 +121,35 @@ def link(**kwargs):
def action(methods=['post'], **kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for POST requests.
+ Used to mark a method on a ViewSet that should be routed for detail POST requests.
"""
def decorator(func):
func.bind_to_methods = methods
- func.collection = False
+ func.detail = True
func.kwargs = kwargs
return func
return decorator
-def collection_link(**kwargs):
+def list_link(**kwargs):
"""
Used to mark a method on a ViewSet that should be routed for GET requests.
"""
def decorator(func):
func.bind_to_methods = ['get']
- func.collection = True
+ func.detail = False
func.kwargs = kwargs
return func
return decorator
-def collection_action(methods=['post'], **kwargs):
+def list_action(methods=['post'], **kwargs):
"""
Used to mark a method on a ViewSet that should be routed for POST requests.
"""
def decorator(func):
func.bind_to_methods = methods
- func.collection = True
+ func.detail = False
func.kwargs = kwargs
return func
return decorator
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 541df4a9..c8f711e9 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -89,8 +89,8 @@ class SimpleRouter(BaseRouter):
name='{basename}-list',
initkwargs={'suffix': 'List'}
),
- # Dynamically generated collection routes.
- # Generated using @collection_action or @collection_link decorators
+ # Dynamically generated list routes.
+ # Generated using @list_action or @list_link decorators
# on methods of the viewset.
Route(
key='collection',
@@ -114,7 +114,7 @@ class SimpleRouter(BaseRouter):
name='{basename}-detail',
initkwargs={'suffix': 'Instance'}
),
- # Dynamically generated routes.
+ # Dynamically generated detail routes.
# Generated using @action or @link decorators on methods of the viewset.
Route(
key='dynamic',
@@ -157,27 +157,28 @@ class SimpleRouter(BaseRouter):
known_actions = flatten([route.mapping.values() for route in self.routes])
# Determine any `@action` or `@link` decorated methods on the viewset
- collection_routes = []
- dynamic_routes = []
+ detail_routes = []
+ list_routes = []
for methodname in dir(viewset):
attr = getattr(viewset, methodname)
httpmethods = getattr(attr, 'bind_to_methods', None)
- collection = getattr(attr, 'collection', False)
+ detail = getattr(attr, 'detail', True)
if httpmethods:
if methodname in known_actions:
- raise ImproperlyConfigured('Cannot use @action or @link decorator on '
- 'method "%s" as it is an existing route' % methodname)
+ raise ImproperlyConfigured('Cannot use @action, @link, @list_action '
+ 'or @list_link decorator on method "%s" '
+ 'as it is an existing route' % methodname)
httpmethods = [method.lower() for method in httpmethods]
- if collection:
- collection_routes.append((httpmethods, methodname))
+ if detail:
+ detail_routes.append((httpmethods, methodname))
else:
- dynamic_routes.append((httpmethods, methodname))
+ list_routes.append((httpmethods, methodname))
ret = []
for route in self.routes:
if route.key == 'dynamic':
- # Dynamic routes (@link or @action decorator)
- for httpmethods, methodname in dynamic_routes:
+ # Dynamic detail routes (@link or @action decorator)
+ for httpmethods, methodname in detail_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
@@ -188,8 +189,8 @@ class SimpleRouter(BaseRouter):
initkwargs=initkwargs,
))
elif route.key == 'collection':
- # Dynamic routes (@collection_link or @collection_action decorator)
- for httpmethods, methodname in collection_routes:
+ # Dynamic list routes (@list_link or @list_action decorator)
+ for httpmethods, methodname in list_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index e0a7e292..39310176 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -4,7 +4,7 @@ from django.test import TestCase
from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers, viewsets, permissions
from rest_framework.compat import include, patterns, url
-from rest_framework.decorators import link, action, collection_link, collection_action
+from rest_framework.decorators import link, action, list_link, list_action
from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter
from rest_framework.test import APIRequestFactory
@@ -216,39 +216,39 @@ class TestActionAppliedToExistingRoute(TestCase):
self.router.urls
-class CollectionAndDynamicViewSet(viewsets.ViewSet):
+class DynamicListAndDetailViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
return Response({'method': 'list'})
- @collection_action()
- def collection_action(self, request, *args, **kwargs):
+ @list_action()
+ def list_action(self, request, *args, **kwargs):
return Response({'method': 'action1'})
@action()
- def dynamic_action(self, request, *args, **kwargs):
+ def detail_action(self, request, *args, **kwargs):
return Response({'method': 'action2'})
- @collection_link()
- def collection_link(self, request, *args, **kwargs):
+ @list_link()
+ def list_link(self, request, *args, **kwargs):
return Response({'method': 'link1'})
@link()
- def dynamic_link(self, request, *args, **kwargs):
+ def detail_link(self, request, *args, **kwargs):
return Response({'method': 'link2'})
-class TestCollectionAndDynamicRouter(TestCase):
+class TestDynamicListAndDetailRouter(TestCase):
def setUp(self):
self.router = SimpleRouter()
def test_link_and_action_decorator(self):
- routes = self.router.get_routes(CollectionAndDynamicViewSet)
+ routes = self.router.get_routes(DynamicListAndDetailViewSet)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
# Make sure all these endpoints exist and none have been clobbered
- for i, endpoint in enumerate(['collection_action', 'collection_link', 'dynamic_action', 'dynamic_link']):
+ for i, endpoint in enumerate(['list_action', 'list_link', 'detail_action', 'detail_link']):
route = decorator_routes[i]
# check url listing
- if endpoint.startswith('collection_'):
+ if endpoint.startswith('list_'):
self.assertEqual(route.url,
'^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint))
else:
--
cgit v1.2.3
From ca7ba07b4e42bd1c7c6bb8088c0c5a2c434b56ee Mon Sep 17 00:00:00 2001
From: Alex Burgel
Date: Sat, 13 Jul 2013 11:12:59 -0400
Subject: Introduce DynamicDetailRoute and DynamicListRoute to distinguish
between different route types
---
rest_framework/routers.py | 26 ++++++++------------------
1 file changed, 8 insertions(+), 18 deletions(-)
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index c8f711e9..b8f19b66 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -25,7 +25,9 @@ from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns
-Route = namedtuple('Route', ['key', 'url', 'mapping', 'name', 'initkwargs'])
+Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs'])
+DynamicDetailRoute = namedtuple('DynamicDetailRoute', ['url', 'name', 'initkwargs'])
+DynamicListRoute = namedtuple('DynamicListRoute', ['url', 'name', 'initkwargs'])
def replace_methodname(format_string, methodname):
@@ -80,7 +82,6 @@ class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
- key='list',
url=r'^{prefix}{trailing_slash}$',
mapping={
'get': 'list',
@@ -92,18 +93,13 @@ class SimpleRouter(BaseRouter):
# Dynamically generated list routes.
# Generated using @list_action or @list_link decorators
# on methods of the viewset.
- Route(
- key='collection',
+ DynamicListRoute(
url=r'^{prefix}/{methodname}{trailing_slash}$',
- mapping={
- '{httpmethod}': '{methodname}',
- },
name='{basename}-{methodnamehyphen}',
initkwargs={}
),
# Detail route.
Route(
- key='detail',
url=r'^{prefix}/{lookup}{trailing_slash}$',
mapping={
'get': 'retrieve',
@@ -116,12 +112,8 @@ class SimpleRouter(BaseRouter):
),
# Dynamically generated detail routes.
# Generated using @action or @link decorators on methods of the viewset.
- Route(
- key='dynamic',
+ DynamicDetailRoute(
url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',
- mapping={
- '{httpmethod}': '{methodname}',
- },
name='{basename}-{methodnamehyphen}',
initkwargs={}
),
@@ -154,7 +146,7 @@ class SimpleRouter(BaseRouter):
Returns a list of the Route namedtuple.
"""
- known_actions = flatten([route.mapping.values() for route in self.routes])
+ known_actions = flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)])
# Determine any `@action` or `@link` decorated methods on the viewset
detail_routes = []
@@ -176,25 +168,23 @@ class SimpleRouter(BaseRouter):
ret = []
for route in self.routes:
- if route.key == 'dynamic':
+ if isinstance(route, DynamicDetailRoute):
# Dynamic detail routes (@link or @action decorator)
for httpmethods, methodname in detail_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
- key=route.key,
url=replace_methodname(route.url, methodname),
mapping=dict((httpmethod, methodname) for httpmethod in httpmethods),
name=replace_methodname(route.name, methodname),
initkwargs=initkwargs,
))
- elif route.key == 'collection':
+ elif isinstance(route, DynamicListRoute):
# Dynamic list routes (@list_link or @list_action decorator)
for httpmethods, methodname in list_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
ret.append(Route(
- key=route.key,
url=replace_methodname(route.url, methodname),
mapping=dict((httpmethod, methodname) for httpmethod in httpmethods),
name=replace_methodname(route.name, methodname),
--
cgit v1.2.3
From eaae8fb2d973769a827214e0606a7e41028d5d34 Mon Sep 17 00:00:00 2001
From: Alex Burgel
Date: Mon, 15 Jul 2013 18:35:13 -0400
Subject: Combined link_* and action_* decorators into detail_route and
list_route, marked the originals as deprecated.
---
docs/api-guide/routers.md | 16 ++++++-------
docs/api-guide/viewsets.md | 16 ++++++-------
docs/tutorial/6-viewsets-and-routers.md | 8 +++----
rest_framework/decorators.py | 19 ++++++++++------
rest_framework/routers.py | 14 ++++++------
rest_framework/tests/test_routers.py | 40 ++++++++++++++++-----------------
6 files changed, 59 insertions(+), 54 deletions(-)
diff --git a/docs/api-guide/routers.md b/docs/api-guide/routers.md
index 86582905..f196dc3c 100644
--- a/docs/api-guide/routers.md
+++ b/docs/api-guide/routers.md
@@ -35,12 +35,12 @@ The example above would generate the following URL patterns:
* URL pattern: `^accounts/$` Name: `'account-list'`
* URL pattern: `^accounts/{pk}/$` Name: `'account-detail'`
-### Extra link and actions
+### Registering additional routes
-Any methods on the viewset decorated with `@link` or `@action` will also be routed.
+Any methods on the viewset decorated with `@detail_route` or `@list_route` will also be routed.
For example, a given method like this on the `UserViewSet` class:
- @action(permission_classes=[IsAdminOrIsSelf])
+ @detail_route(methods=['post'], permission_classes=[IsAdminOrIsSelf])
def set_password(self, request, pk=None):
...
@@ -52,7 +52,7 @@ The following URL pattern would additionally be generated:
## SimpleRouter
-This router includes routes for the standard set of `list`, `create`, `retrieve`, `update`, `partial_update` and `destroy` actions. The viewset can also mark additional methods to be routed, using the `@link` or `@action` decorators.
+This router includes routes for the standard set of `list`, `create`, `retrieve`, `update`, `partial_update` and `destroy` actions. The viewset can also mark additional methods to be routed, using the `@detail_route` or `@list_route` decorators.
URL Style
HTTP Method
Action
URL Name
@@ -62,8 +62,8 @@ This router includes routes for the standard set of `list`, `create`, `retrieve`
PUT
update
PATCH
partial_update
DELETE
destroy
-
{prefix}/{lookup}/{methodname}/
GET
@link decorated method
{basename}-{methodname}
-
POST
@action decorated method
+
{prefix}/{lookup}/{methodname}/
GET
@detail_route decorated method
{basename}-{methodname}
+
POST
@detail_route decorated method
By default the URLs created by `SimpleRouter` are appending with a trailing slash.
@@ -86,8 +86,8 @@ This router is similar to `SimpleRouter` as above, but additionally includes a d
PUT
update
PATCH
partial_update
DELETE
destroy
-
{prefix}/{lookup}/{methodname}/[.format]
GET
@link decorated method
{basename}-{methodname}
-
POST
@action decorated method
+
{prefix}/{lookup}/{methodname}/[.format]
GET
@detail_route decorated method
{basename}-{methodname}
+
POST
@detail_route decorated method
As with `SimpleRouter` the trailing slashs on the URL routes can be removed by setting the `trailing_slash` argument to `False` when instantiating the router.
diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md
index 6d6bb133..7a8d5979 100644
--- a/docs/api-guide/viewsets.md
+++ b/docs/api-guide/viewsets.md
@@ -92,15 +92,15 @@ The default routers included with REST framework will provide routes for a stand
def destroy(self, request, pk=None):
pass
-If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@link`, `@action`, `@list_link`, or `@list_action` decorators. The `@link` and `@list_link` decorators will route `GET` requests, and the `@action` and `@list_action` decorators will route `POST` requests.
+If you have ad-hoc methods that you need to be routed to, you can mark them as requiring routing using the `@detail_route` or `@list_route` decorators.
-The `@link` and `@action` decorators contain `pk` in their URL pattern and are intended for methods which require a single instance. The `@list_link` and `@list_action` decorators are intended for methods which operate on a list of objects.
+The `@detail_route` decorator contains `pk` in its URL pattern and is intended for methods which require a single instance. The `@list_route` decorator is intended for methods which operate on a list of objects.
For example:
from django.contrib.auth.models import User
from rest_framework import viewsets
- from rest_framework.decorators import action, list_link
+ from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response
from myapp.serializers import UserSerializer, PasswordSerializer
@@ -111,7 +111,7 @@ For example:
queryset = User.objects.all()
serializer_class = UserSerializer
- @action()
+ @detail_route(methods=['post'])
def set_password(self, request, pk=None):
user = self.get_object()
serializer = PasswordSerializer(data=request.DATA)
@@ -123,7 +123,7 @@ For example:
return Response(serializer.errors,
status=status.HTTP_400_BAD_REQUEST)
- @list_link()
+ @list_route()
def recent_users(self, request):
recent_users = User.objects.all().order('-last_login')
page = self.paginate_queryset(recent_users)
@@ -132,13 +132,13 @@ For example:
The decorators can additionally take extra arguments that will be set for the routed view only. For example...
- @action(permission_classes=[IsAdminOrIsSelf])
+ @detail_route(methods=['post'], permission_classes=[IsAdminOrIsSelf])
def set_password(self, request, pk=None):
...
-The `@action` and `@list_action` decorators will route `POST` requests by default, but may also accept other HTTP methods, by using the `methods` argument. For example:
+By default, the decorators will route `GET` requests, but may also accept other HTTP methods, by using the `methods` argument. For example:
- @action(methods=['POST', 'DELETE'])
+ @detail_route(methods=['post', 'delete'])
def unset_password(self, request, pk=None):
...
---
diff --git a/docs/tutorial/6-viewsets-and-routers.md b/docs/tutorial/6-viewsets-and-routers.md
index f16add39..f126ba04 100644
--- a/docs/tutorial/6-viewsets-and-routers.md
+++ b/docs/tutorial/6-viewsets-and-routers.md
@@ -25,7 +25,7 @@ Here we've used `ReadOnlyModelViewSet` class to automatically provide the defaul
Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class.
- from rest_framework.decorators import link
+ from rest_framework.decorators import detail_route
class SnippetViewSet(viewsets.ModelViewSet):
"""
@@ -39,7 +39,7 @@ Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighl
permission_classes = (permissions.IsAuthenticatedOrReadOnly,
IsOwnerOrReadOnly,)
- @link(renderer_classes=[renderers.StaticHTMLRenderer])
+ @detail_route(renderer_classes=[renderers.StaticHTMLRenderer])
def highlight(self, request, *args, **kwargs):
snippet = self.get_object()
return Response(snippet.highlighted)
@@ -49,9 +49,9 @@ Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighl
This time we've used the `ModelViewSet` class in order to get the complete set of default read and write operations.
-Notice that we've also used the `@link` decorator to create a custom action, named `highlight`. This decorator can be used to add any custom endpoints that don't fit into the standard `create`/`update`/`delete` style.
+Notice that we've also used the `@detail_route` decorator to create a custom action, named `highlight`. This decorator can be used to add any custom endpoints that don't fit into the standard `create`/`update`/`delete` style.
-Custom actions which use the `@link` decorator will respond to `GET` requests. We could have instead used the `@action` decorator if we wanted an action that responded to `POST` requests.
+Custom actions which use the `@detail_route` decorator will respond to `GET` requests. We can use the `methods` argument if we wanted an action that responded to `POST` requests.
## Binding ViewSets to URLs explicitly
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 92f551db..1ca176f2 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -3,13 +3,14 @@ The most important decorator in this module is `@api_view`, which is used
for writing function-based views with REST framework.
There are also various decorators for setting the API policies on function
-based views, as well as the `@action` and `@link` decorators, which are
+based views, as well as the `@detail_route` and `@list_route` decorators, which are
used to annotate methods on viewsets that should be included by routers.
"""
from __future__ import unicode_literals
from rest_framework.compat import six
from rest_framework.views import APIView
import types
+import warnings
def api_view(http_method_names):
@@ -111,6 +112,8 @@ def link(**kwargs):
"""
Used to mark a method on a ViewSet that should be routed for detail GET requests.
"""
+ msg = 'link is pending deprecation. Use detail_route instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
func.bind_to_methods = ['get']
func.detail = True
@@ -123,6 +126,8 @@ def action(methods=['post'], **kwargs):
"""
Used to mark a method on a ViewSet that should be routed for detail POST requests.
"""
+ msg = 'action is pending deprecation. Use detail_route instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
func.bind_to_methods = methods
func.detail = True
@@ -131,21 +136,21 @@ def action(methods=['post'], **kwargs):
return decorator
-def list_link(**kwargs):
+def detail_route(methods=['get'], **kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for GET requests.
+ Used to mark a method on a ViewSet that should be routed for detail requests.
"""
def decorator(func):
- func.bind_to_methods = ['get']
- func.detail = False
+ func.bind_to_methods = methods
+ func.detail = True
func.kwargs = kwargs
return func
return decorator
-def list_action(methods=['post'], **kwargs):
+def list_route(methods=['get'], **kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for POST requests.
+ Used to mark a method on a ViewSet that should be routed for list requests.
"""
def decorator(func):
func.bind_to_methods = methods
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index b8f19b66..b761ba9a 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -91,7 +91,7 @@ class SimpleRouter(BaseRouter):
initkwargs={'suffix': 'List'}
),
# Dynamically generated list routes.
- # Generated using @list_action or @list_link decorators
+ # Generated using @list_route decorator
# on methods of the viewset.
DynamicListRoute(
url=r'^{prefix}/{methodname}{trailing_slash}$',
@@ -111,7 +111,7 @@ class SimpleRouter(BaseRouter):
initkwargs={'suffix': 'Instance'}
),
# Dynamically generated detail routes.
- # Generated using @action or @link decorators on methods of the viewset.
+ # Generated using @detail_route decorator on methods of the viewset.
DynamicDetailRoute(
url=r'^{prefix}/{lookup}/{methodname}{trailing_slash}$',
name='{basename}-{methodnamehyphen}',
@@ -148,7 +148,7 @@ class SimpleRouter(BaseRouter):
known_actions = flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)])
- # Determine any `@action` or `@link` decorated methods on the viewset
+ # Determine any `@detail_route` or `@list_route` decorated methods on the viewset
detail_routes = []
list_routes = []
for methodname in dir(viewset):
@@ -157,8 +157,8 @@ class SimpleRouter(BaseRouter):
detail = getattr(attr, 'detail', True)
if httpmethods:
if methodname in known_actions:
- raise ImproperlyConfigured('Cannot use @action, @link, @list_action '
- 'or @list_link decorator on method "%s" '
+ raise ImproperlyConfigured('Cannot use @detail_route or @list_route '
+ 'decorators on method "%s" '
'as it is an existing route' % methodname)
httpmethods = [method.lower() for method in httpmethods]
if detail:
@@ -169,7 +169,7 @@ class SimpleRouter(BaseRouter):
ret = []
for route in self.routes:
if isinstance(route, DynamicDetailRoute):
- # Dynamic detail routes (@link or @action decorator)
+ # Dynamic detail routes (@detail_route decorator)
for httpmethods, methodname in detail_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
@@ -180,7 +180,7 @@ class SimpleRouter(BaseRouter):
initkwargs=initkwargs,
))
elif isinstance(route, DynamicListRoute):
- # Dynamic list routes (@list_link or @list_action decorator)
+ # Dynamic list routes (@list_route decorator)
for httpmethods, methodname in list_routes:
initkwargs = route.initkwargs.copy()
initkwargs.update(getattr(viewset, methodname).kwargs)
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 39310176..c3597e38 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -4,7 +4,7 @@ from django.test import TestCase
from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers, viewsets, permissions
from rest_framework.compat import include, patterns, url
-from rest_framework.decorators import link, action, list_link, list_action
+from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter
from rest_framework.test import APIRequestFactory
@@ -18,23 +18,23 @@ class BasicViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
return Response({'method': 'list'})
- @action()
+ @detail_route(methods=['post'])
def action1(self, request, *args, **kwargs):
return Response({'method': 'action1'})
- @action()
+ @detail_route(methods=['post'])
def action2(self, request, *args, **kwargs):
return Response({'method': 'action2'})
- @action(methods=['post', 'delete'])
+ @detail_route(methods=['post', 'delete'])
def action3(self, request, *args, **kwargs):
return Response({'method': 'action2'})
- @link()
+ @detail_route()
def link1(self, request, *args, **kwargs):
return Response({'method': 'link1'})
- @link()
+ @detail_route()
def link2(self, request, *args, **kwargs):
return Response({'method': 'link2'})
@@ -175,7 +175,7 @@ class TestActionKeywordArgs(TestCase):
class TestViewSet(viewsets.ModelViewSet):
permission_classes = []
- @action(permission_classes=[permissions.AllowAny])
+ @detail_route(methods=['post'], permission_classes=[permissions.AllowAny])
def custom(self, request, *args, **kwargs):
return Response({
'permission_classes': self.permission_classes
@@ -196,14 +196,14 @@ class TestActionKeywordArgs(TestCase):
class TestActionAppliedToExistingRoute(TestCase):
"""
- Ensure `@action` decorator raises an except when applied
+ Ensure `@detail_route` decorator raises an except when applied
to an existing route
"""
def test_exception_raised_when_action_applied_to_existing_route(self):
class TestViewSet(viewsets.ModelViewSet):
- @action()
+ @detail_route(methods=['post'])
def retrieve(self, request, *args, **kwargs):
return Response({
'hello': 'world'
@@ -220,20 +220,20 @@ class DynamicListAndDetailViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
return Response({'method': 'list'})
- @list_action()
- def list_action(self, request, *args, **kwargs):
+ @list_route(methods=['post'])
+ def list_route_post(self, request, *args, **kwargs):
return Response({'method': 'action1'})
- @action()
- def detail_action(self, request, *args, **kwargs):
+ @detail_route(methods=['post'])
+ def detail_route_post(self, request, *args, **kwargs):
return Response({'method': 'action2'})
- @list_link()
- def list_link(self, request, *args, **kwargs):
+ @list_route()
+ def list_route_get(self, request, *args, **kwargs):
return Response({'method': 'link1'})
- @link()
- def detail_link(self, request, *args, **kwargs):
+ @detail_route()
+ def detail_route_get(self, request, *args, **kwargs):
return Response({'method': 'link2'})
@@ -241,11 +241,11 @@ class TestDynamicListAndDetailRouter(TestCase):
def setUp(self):
self.router = SimpleRouter()
- def test_link_and_action_decorator(self):
+ def test_list_and_detail_route_decorators(self):
routes = self.router.get_routes(DynamicListAndDetailViewSet)
decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
# Make sure all these endpoints exist and none have been clobbered
- for i, endpoint in enumerate(['list_action', 'list_link', 'detail_action', 'detail_link']):
+ for i, endpoint in enumerate(['list_route_get', 'list_route_post', 'detail_route_get', 'detail_route_post']):
route = decorator_routes[i]
# check url listing
if endpoint.startswith('list_'):
@@ -255,7 +255,7 @@ class TestDynamicListAndDetailRouter(TestCase):
self.assertEqual(route.url,
'^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
# check method to function mapping
- if endpoint.endswith('action'):
+ if endpoint.endswith('_post'):
method_map = 'post'
else:
method_map = 'get'
--
cgit v1.2.3
From 4292cc18fa3e4b3f5e67c02c3780cdcbf901a0a1 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 19 Aug 2013 20:53:30 +0100
Subject: Docs tweaking
---
docs/api-guide/routers.md | 11 +++++++----
docs/api-guide/viewsets.md | 2 +-
2 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/docs/api-guide/routers.md b/docs/api-guide/routers.md
index 7884c2e9..c8465418 100644
--- a/docs/api-guide/routers.md
+++ b/docs/api-guide/routers.md
@@ -48,6 +48,8 @@ The following URL pattern would additionally be generated:
* URL pattern: `^users/{pk}/set_password/$` Name: `'user-set-password'`
+For more information see the viewset documentation on [marking extra actions for routing][route-decorators].
+
# API Guide
## SimpleRouter
@@ -58,12 +60,12 @@ This router includes routes for the standard set of `list`, `create`, `retrieve`
URL Style
HTTP Method
Action
URL Name
{prefix}/
GET
list
{basename}-list
POST
create
+
{prefix}/{methodname}/
GET, or as specified by `methods` argument
`@list_route` decorated method
{basename}-{methodname}
{prefix}/{lookup}/
GET
retrieve
{basename}-detail
PUT
update
PATCH
partial_update
DELETE
destroy
-
{prefix}/{lookup}/{methodname}/
GET
@detail_route decorated method
{basename}-{methodname}
-
POST
@detail_route decorated method
+
{prefix}/{lookup}/{methodname}/
GET, or as specified by `methods` argument
`@detail_route` decorated method
{basename}-{methodname}
By default the URLs created by `SimpleRouter` are appended with a trailing slash.
@@ -82,12 +84,12 @@ This router is similar to `SimpleRouter` as above, but additionally includes a d
[.format]
GET
automatically generated root view
api-root
{prefix}/[.format]
GET
list
{basename}-list
POST
create
+
{prefix}/{methodname}/[.format]
GET, or as specified by `methods` argument
`@list_route` decorated method
{basename}-{methodname}
{prefix}/{lookup}/[.format]
GET
retrieve
{basename}-detail
PUT
update
PATCH
partial_update
DELETE
destroy
-
{prefix}/{lookup}/{methodname}/[.format]
GET
@detail_route decorated method
{basename}-{methodname}
-
POST
@detail_route decorated method
+
{prefix}/{lookup}/{methodname}/[.format]
GET, or as specified by `methods` argument
`@detail_route` decorated method
{basename}-{methodname}
As with `SimpleRouter` the trailing slashes on the URL routes can be removed by setting the `trailing_slash` argument to `False` when instantiating the router.
@@ -144,3 +146,4 @@ If you want to provide totally custom behavior, you can override `BaseRouter` an
You may also want to override the `get_default_base_name(self, viewset)` method, or else always explicitly set the `base_name` argument when registering your viewsets with the router.
[cite]: http://guides.rubyonrails.org/routing.html
+[route-decorators]: viewsets.html#marking-extra-actions-for-routing
\ No newline at end of file
diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md
index 95efc229..9005e7cb 100644
--- a/docs/api-guide/viewsets.md
+++ b/docs/api-guide/viewsets.md
@@ -61,7 +61,7 @@ There are two main advantages of using a `ViewSet` class over using a `View` cla
Both of these come with a trade-off. Using regular views and URL confs is more explicit and gives you more control. ViewSets are helpful if you want to get up and running quickly, or when you have a large API and you want to enforce a consistent URL configuration throughout.
-## Marking extra methods for routing
+## Marking extra actions for routing
The default routers included with REST framework will provide routes for a standard set of create/retrieve/update/destroy style operations, as shown below:
--
cgit v1.2.3
From 8acee2e626746f3096c49b3ebb13aaf7dc882917 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 19 Aug 2013 21:02:22 +0100
Subject: Commenting link/action decorators as pending deprecation
---
rest_framework/decorators.py | 35 ++++++++++++++++++-----------------
1 file changed, 18 insertions(+), 17 deletions(-)
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 1ca176f2..18e41a18 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -108,53 +108,54 @@ def permission_classes(permission_classes):
return decorator
-def link(**kwargs):
+def detail_route(methods=['get'], **kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for detail GET requests.
+ Used to mark a method on a ViewSet that should be routed for detail requests.
"""
- msg = 'link is pending deprecation. Use detail_route instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
- func.bind_to_methods = ['get']
+ func.bind_to_methods = methods
func.detail = True
func.kwargs = kwargs
return func
return decorator
-def action(methods=['post'], **kwargs):
+def list_route(methods=['get'], **kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for detail POST requests.
+ Used to mark a method on a ViewSet that should be routed for list requests.
"""
- msg = 'action is pending deprecation. Use detail_route instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
func.bind_to_methods = methods
- func.detail = True
+ func.detail = False
func.kwargs = kwargs
return func
return decorator
+# These are now pending deprecation, in favor of `detail_route` and `list_route`.
-def detail_route(methods=['get'], **kwargs):
+def link(**kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for detail requests.
+ Used to mark a method on a ViewSet that should be routed for detail GET requests.
"""
+ msg = 'link is pending deprecation. Use detail_route instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
- func.bind_to_methods = methods
+ func.bind_to_methods = ['get']
func.detail = True
func.kwargs = kwargs
return func
return decorator
-def list_route(methods=['get'], **kwargs):
+def action(methods=['post'], **kwargs):
"""
- Used to mark a method on a ViewSet that should be routed for list requests.
+ Used to mark a method on a ViewSet that should be routed for detail POST requests.
"""
+ msg = 'action is pending deprecation. Use detail_route instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func):
func.bind_to_methods = methods
- func.detail = False
+ func.detail = True
func.kwargs = kwargs
return func
- return decorator
+ return decorator
\ No newline at end of file
--
cgit v1.2.3
From 815ef50735f50c7aff5255e60f1b484e75178e87 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 21 Aug 2013 21:18:46 +0100
Subject: If page size query param <= 0, just use default page size.
Closes #1028
---
rest_framework/generics.py | 11 ++++++++++-
1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 874a142c..bcd62bf9 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -14,6 +14,15 @@ from rest_framework.settings import api_settings
import warnings
+def strict_positive_int(integer_string):
+ """
+ Cast a string to a strictly positive integer.
+ """
+ ret = int(integer_string)
+ if ret <= 0:
+ raise ValueError()
+ return ret
+
def get_object_or_404(queryset, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to raise 404
@@ -198,7 +207,7 @@ class GenericAPIView(views.APIView):
if self.paginate_by_param:
query_params = self.request.QUERY_PARAMS
try:
- return int(query_params[self.paginate_by_param])
+ return strict_positive_int(query_params[self.paginate_by_param])
except (KeyError, ValueError):
pass
--
cgit v1.2.3
From 44ceef841543877a700c3fb4a0f84dfecbad0cbb Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 21 Aug 2013 21:30:25 +0100
Subject: Updating 2.4.0 release notes
---
.travis.yml | 1 +
docs/topics/release-notes.md | 5 +-
rest_framework/compat.py | 2 +-
rest_framework/six.py | 389 -------------------------------------------
tox.ini | 2 +
5 files changed, 8 insertions(+), 391 deletions(-)
delete mode 100644 rest_framework/six.py
diff --git a/.travis.yml b/.travis.yml
index 6a453241..f8640db2 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -19,6 +19,7 @@ install:
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.0 --use-mirrors; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4 --use-mirrors; fi"
- "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4 --use-mirrors; fi"
+ - "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install six --use-mirrors; fi"
- "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.6 --use-mirrors; fi"
- export PYTHONPATH=.
diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md
index 52abfc08..f3bb19c6 100644
--- a/docs/topics/release-notes.md
+++ b/docs/topics/release-notes.md
@@ -40,9 +40,12 @@ You can determine your currently installed version using `pip freeze`:
## 2.3.x series
-### Master
+### 2.4.0
+* `@detail_route` and `@list_route` decorators replace `@action` and `@link`.
+* `six` no longer bundled. For Django <= 1.4.1, install `six` package.
* Support customizable view name and description functions, using the `VIEW_NAME_FUNCTION` and `VIEW_DESCRIPTION_FUNCTION` settings.
+* Bugfix: `?page_size=0` query parameter now falls back to default page size for view, instead of always turning pagination off.
### 2.3.7
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index baee3a9c..178a697f 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -14,7 +14,7 @@ from django.conf import settings
try:
from django.utils import six
except ImportError:
- from rest_framework import six
+ import six
# location of patterns, url, include changes in 1.4 onwards
try:
diff --git a/rest_framework/six.py b/rest_framework/six.py
deleted file mode 100644
index 9e382312..00000000
--- a/rest_framework/six.py
+++ /dev/null
@@ -1,389 +0,0 @@
-"""Utilities for writing code that runs on Python 2 and 3"""
-
-import operator
-import sys
-import types
-
-__author__ = "Benjamin Peterson "
-__version__ = "1.2.0"
-
-
-# True if we are running on Python 3.
-PY3 = sys.version_info[0] == 3
-
-if PY3:
- string_types = str,
- integer_types = int,
- class_types = type,
- text_type = str
- binary_type = bytes
-
- MAXSIZE = sys.maxsize
-else:
- string_types = basestring,
- integer_types = (int, long)
- class_types = (type, types.ClassType)
- text_type = unicode
- binary_type = str
-
- if sys.platform == "java":
- # Jython always uses 32 bits.
- MAXSIZE = int((1 << 31) - 1)
- else:
- # It's possible to have sizeof(long) != sizeof(Py_ssize_t).
- class X(object):
- def __len__(self):
- return 1 << 31
- try:
- len(X())
- except OverflowError:
- # 32-bit
- MAXSIZE = int((1 << 31) - 1)
- else:
- # 64-bit
- MAXSIZE = int((1 << 63) - 1)
- del X
-
-
-def _add_doc(func, doc):
- """Add documentation to a function."""
- func.__doc__ = doc
-
-
-def _import_module(name):
- """Import module, returning the module after the last dot."""
- __import__(name)
- return sys.modules[name]
-
-
-class _LazyDescr(object):
-
- def __init__(self, name):
- self.name = name
-
- def __get__(self, obj, tp):
- result = self._resolve()
- setattr(obj, self.name, result)
- # This is a bit ugly, but it avoids running this again.
- delattr(tp, self.name)
- return result
-
-
-class MovedModule(_LazyDescr):
-
- def __init__(self, name, old, new=None):
- super(MovedModule, self).__init__(name)
- if PY3:
- if new is None:
- new = name
- self.mod = new
- else:
- self.mod = old
-
- def _resolve(self):
- return _import_module(self.mod)
-
-
-class MovedAttribute(_LazyDescr):
-
- def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
- super(MovedAttribute, self).__init__(name)
- if PY3:
- if new_mod is None:
- new_mod = name
- self.mod = new_mod
- if new_attr is None:
- if old_attr is None:
- new_attr = name
- else:
- new_attr = old_attr
- self.attr = new_attr
- else:
- self.mod = old_mod
- if old_attr is None:
- old_attr = name
- self.attr = old_attr
-
- def _resolve(self):
- module = _import_module(self.mod)
- return getattr(module, self.attr)
-
-
-
-class _MovedItems(types.ModuleType):
- """Lazy loading of moved objects"""
-
-
-_moved_attributes = [
- MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
- MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
- MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
- MovedAttribute("map", "itertools", "builtins", "imap", "map"),
- MovedAttribute("reload_module", "__builtin__", "imp", "reload"),
- MovedAttribute("reduce", "__builtin__", "functools"),
- MovedAttribute("StringIO", "StringIO", "io"),
- MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
- MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
-
- MovedModule("builtins", "__builtin__"),
- MovedModule("configparser", "ConfigParser"),
- MovedModule("copyreg", "copy_reg"),
- MovedModule("http_cookiejar", "cookielib", "http.cookiejar"),
- MovedModule("http_cookies", "Cookie", "http.cookies"),
- MovedModule("html_entities", "htmlentitydefs", "html.entities"),
- MovedModule("html_parser", "HTMLParser", "html.parser"),
- MovedModule("http_client", "httplib", "http.client"),
- MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
- MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
- MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
- MovedModule("cPickle", "cPickle", "pickle"),
- MovedModule("queue", "Queue"),
- MovedModule("reprlib", "repr"),
- MovedModule("socketserver", "SocketServer"),
- MovedModule("tkinter", "Tkinter"),
- MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"),
- MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"),
- MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"),
- MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"),
- MovedModule("tkinter_tix", "Tix", "tkinter.tix"),
- MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
- MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
- MovedModule("tkinter_colorchooser", "tkColorChooser",
- "tkinter.colorchooser"),
- MovedModule("tkinter_commondialog", "tkCommonDialog",
- "tkinter.commondialog"),
- MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
- MovedModule("tkinter_font", "tkFont", "tkinter.font"),
- MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
- MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
- "tkinter.simpledialog"),
- MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"),
- MovedModule("winreg", "_winreg"),
-]
-for attr in _moved_attributes:
- setattr(_MovedItems, attr.name, attr)
-del attr
-
-moves = sys.modules["django.utils.six.moves"] = _MovedItems("moves")
-
-
-def add_move(move):
- """Add an item to six.moves."""
- setattr(_MovedItems, move.name, move)
-
-
-def remove_move(name):
- """Remove item from six.moves."""
- try:
- delattr(_MovedItems, name)
- except AttributeError:
- try:
- del moves.__dict__[name]
- except KeyError:
- raise AttributeError("no such move, %r" % (name,))
-
-
-if PY3:
- _meth_func = "__func__"
- _meth_self = "__self__"
-
- _func_code = "__code__"
- _func_defaults = "__defaults__"
-
- _iterkeys = "keys"
- _itervalues = "values"
- _iteritems = "items"
-else:
- _meth_func = "im_func"
- _meth_self = "im_self"
-
- _func_code = "func_code"
- _func_defaults = "func_defaults"
-
- _iterkeys = "iterkeys"
- _itervalues = "itervalues"
- _iteritems = "iteritems"
-
-
-try:
- advance_iterator = next
-except NameError:
- def advance_iterator(it):
- return it.next()
-next = advance_iterator
-
-
-if PY3:
- def get_unbound_function(unbound):
- return unbound
-
- Iterator = object
-
- def callable(obj):
- return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
-else:
- def get_unbound_function(unbound):
- return unbound.im_func
-
- class Iterator(object):
-
- def next(self):
- return type(self).__next__(self)
-
- callable = callable
-_add_doc(get_unbound_function,
- """Get the function out of a possibly unbound function""")
-
-
-get_method_function = operator.attrgetter(_meth_func)
-get_method_self = operator.attrgetter(_meth_self)
-get_function_code = operator.attrgetter(_func_code)
-get_function_defaults = operator.attrgetter(_func_defaults)
-
-
-def iterkeys(d):
- """Return an iterator over the keys of a dictionary."""
- return iter(getattr(d, _iterkeys)())
-
-def itervalues(d):
- """Return an iterator over the values of a dictionary."""
- return iter(getattr(d, _itervalues)())
-
-def iteritems(d):
- """Return an iterator over the (key, value) pairs of a dictionary."""
- return iter(getattr(d, _iteritems)())
-
-
-if PY3:
- def b(s):
- return s.encode("latin-1")
- def u(s):
- return s
- if sys.version_info[1] <= 1:
- def int2byte(i):
- return bytes((i,))
- else:
- # This is about 2x faster than the implementation above on 3.2+
- int2byte = operator.methodcaller("to_bytes", 1, "big")
- import io
- StringIO = io.StringIO
- BytesIO = io.BytesIO
-else:
- def b(s):
- return s
- def u(s):
- return unicode(s, "unicode_escape")
- int2byte = chr
- import StringIO
- StringIO = BytesIO = StringIO.StringIO
-_add_doc(b, """Byte literal""")
-_add_doc(u, """Text literal""")
-
-
-if PY3:
- import builtins
- exec_ = getattr(builtins, "exec")
-
-
- def reraise(tp, value, tb=None):
- if value.__traceback__ is not tb:
- raise value.with_traceback(tb)
- raise value
-
-
- print_ = getattr(builtins, "print")
- del builtins
-
-else:
- def exec_(code, globs=None, locs=None):
- """Execute code in a namespace."""
- if globs is None:
- frame = sys._getframe(1)
- globs = frame.f_globals
- if locs is None:
- locs = frame.f_locals
- del frame
- elif locs is None:
- locs = globs
- exec("""exec code in globs, locs""")
-
-
- exec_("""def reraise(tp, value, tb=None):
- raise tp, value, tb
-""")
-
-
- def print_(*args, **kwargs):
- """The new-style print function."""
- fp = kwargs.pop("file", sys.stdout)
- if fp is None:
- return
- def write(data):
- if not isinstance(data, basestring):
- data = str(data)
- fp.write(data)
- want_unicode = False
- sep = kwargs.pop("sep", None)
- if sep is not None:
- if isinstance(sep, unicode):
- want_unicode = True
- elif not isinstance(sep, str):
- raise TypeError("sep must be None or a string")
- end = kwargs.pop("end", None)
- if end is not None:
- if isinstance(end, unicode):
- want_unicode = True
- elif not isinstance(end, str):
- raise TypeError("end must be None or a string")
- if kwargs:
- raise TypeError("invalid keyword arguments to print()")
- if not want_unicode:
- for arg in args:
- if isinstance(arg, unicode):
- want_unicode = True
- break
- if want_unicode:
- newline = unicode("\n")
- space = unicode(" ")
- else:
- newline = "\n"
- space = " "
- if sep is None:
- sep = space
- if end is None:
- end = newline
- for i, arg in enumerate(args):
- if i:
- write(sep)
- write(arg)
- write(end)
-
-_add_doc(reraise, """Reraise an exception.""")
-
-
-def with_metaclass(meta, base=object):
- """Create a base class with a metaclass."""
- return meta("NewBase", (base,), {})
-
-
-### Additional customizations for Django ###
-
-if PY3:
- _iterlists = "lists"
- _assertRaisesRegex = "assertRaisesRegex"
-else:
- _iterlists = "iterlists"
- _assertRaisesRegex = "assertRaisesRegexp"
-
-
-def iterlists(d):
- """Return an iterator over the values of a MultiValueDict."""
- return getattr(d, _iterlists)()
-
-
-def assertRaisesRegex(self, *args, **kwargs):
- return getattr(self, _assertRaisesRegex)(*args, **kwargs)
-
-
-add_move(MovedModule("_dummy_thread", "dummy_thread"))
-add_move(MovedModule("_thread", "thread"))
diff --git a/tox.ini b/tox.ini
index aa97fd35..6ec400dd 100644
--- a/tox.ini
+++ b/tox.ini
@@ -91,6 +91,7 @@ deps = django==1.3.5
django-oauth-plus==2.0
oauth2==1.5.211
django-oauth2-provider==0.2.3
+ six
[testenv:py2.6-django1.3]
basepython = python2.6
@@ -100,3 +101,4 @@ deps = django==1.3.5
django-oauth-plus==2.0
oauth2==1.5.211
django-oauth2-provider==0.2.3
+ six
--
cgit v1.2.3
From f631f55f8ebdf3d4e478aa5ca435ad36e86bee0f Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 21 Aug 2013 21:35:17 +0100
Subject: Tweak comment
---
rest_framework/compat.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 178a697f..66be96a6 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -10,7 +10,7 @@ import django
from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
-# Try to import six from Django, fallback to included `six`.
+# Try to import six from Django, fallback to external `six` package.
try:
from django.utils import six
except ImportError:
--
cgit v1.2.3
From bf07b8e616bd92e4ae3c2c09b198181d7075e6bd Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 29 Aug 2013 08:53:19 +0100
Subject: Better docs for customizing dynamic routes. Refs #908
---
docs/api-guide/routers.md | 81 ++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 70 insertions(+), 11 deletions(-)
diff --git a/docs/api-guide/routers.md b/docs/api-guide/routers.md
index f083b3d4..730fa876 100644
--- a/docs/api-guide/routers.md
+++ b/docs/api-guide/routers.md
@@ -123,28 +123,87 @@ The arguments to the `Route` named tuple are:
**initkwargs**: A dictionary of any additional arguments that should be passed when instantiating the view. Note that the `suffix` argument is reserved for identifying the viewset type, used when generating the view name and breadcrumb links.
+## Customizing dynamic routes
+
+You can also customize how the `@list_route` and `@detail_route` decorators are routed.
+To route either or both of these decorators, include a `DynamicListRoute` and/or `DynamicDetailRoute` named tuple in the `.routes` list.
+
+The arguments to `DynamicListRoute` and `DynamicDetailRoute` are:
+
+**url**: A string representing the URL to be routed. May include the same format strings as `Route`, and additionally accepts the `{methodname}` and `{methodnamehyphen}` format strings.
+
+**name**: The name of the URL as used in `reverse` calls. May include the following format strings: `{basename}`, `{methodname}` and `{methodnamehyphen}`.
+
+**initkwargs**: A dictionary of any additional arguments that should be passed when instantiating the view.
+
## Example
The following example will only route to the `list` and `retrieve` actions, and does not use the trailing slash convention.
- from rest_framework.routers import Route, SimpleRouter
+ from rest_framework.routers import Route, DynamicDetailRoute, SimpleRouter
- class ReadOnlyRouter(SimpleRouter):
+ class CustomReadOnlyRouter(SimpleRouter):
"""
A router for read-only APIs, which doesn't use trailing slashes.
"""
routes = [
- Route(url=r'^{prefix}$',
- mapping={'get': 'list'},
- name='{basename}-list',
- initkwargs={'suffix': 'List'}),
- Route(url=r'^{prefix}/{lookup}$',
- mapping={'get': 'retrieve'},
- name='{basename}-detail',
- initkwargs={'suffix': 'Detail'})
+ Route(
+ url=r'^{prefix}$',
+ mapping={'get': 'list'},
+ name='{basename}-list',
+ initkwargs={'suffix': 'List'}
+ ),
+ Route(
+ url=r'^{prefix}/{lookup}$',
+ mapping={'get': 'retrieve'},
+ name='{basename}-detail',
+ initkwargs={'suffix': 'Detail'}
+ ),
+ DynamicDetailRoute(
+ url=r'^{prefix}/{lookup}/{methodnamehyphen}$',
+ name='{basename}-{methodnamehyphen}',
+ initkwargs={}
+ )
]
-The `SimpleRouter` class provides another example of setting the `.routes` attribute.
+Let's take a look at the routes our `CustomReadOnlyRouter` would generate for a simple viewset.
+
+`views.py`:
+
+ class UserViewSet(viewsets.ReadOnlyModelViewSet):
+ """
+ A viewset that provides the standard actions
+ """
+ queryset = User.objects.all()
+ serializer_class = UserSerializer
+ lookup_field = 'username'
+
+ @detail_route()
+ def group_names(self, request):
+ """
+ Returns a list of all the group names that the given
+ user belongs to.
+ """
+ user = self.get_object()
+ groups = user.groups.all()
+ return Response([group.name for group in groups])
+
+`urls.py`:
+
+ router = CustomReadOnlyRouter()
+ router.register('users', UserViewSet)
+ urlpatterns = router.urls
+
+The following mappings would be generated...
+
+
+
URL
HTTP Method
Action
URL Name
+
/users
GET
list
user-list
+
/users/{username}
GET
retrieve
user-detail
+
/users/{username}/group-names
GET
group_names
user-group-names
+
+
+For another example of setting the `.routes` attribute, see the source code for the `SimpleRouter` class.
## Advanced custom routers
--
cgit v1.2.3
From e441f85109e64345a12e65062fc0e51c5787e67f Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 25 Sep 2013 10:30:04 +0100
Subject: Drop 1.3 support
---
.travis.yml | 12 +-
rest_framework/authentication.py | 2 +-
rest_framework/compat.py | 383 ++-------------------
rest_framework/fields.py | 4 +-
rest_framework/routers.py | 2 +-
rest_framework/runtests/settings.py | 5 +-
rest_framework/runtests/urls.py | 2 +-
rest_framework/serializers.py | 6 +-
rest_framework/templates/rest_framework/base.html | 1 +
.../templates/rest_framework/login_base.html | 1 +
rest_framework/templatetags/rest_framework.py | 87 +----
rest_framework/tests/test_authentication.py | 2 +-
rest_framework/tests/test_breadcrumbs.py | 2 +-
rest_framework/tests/test_filters.py | 3 +-
rest_framework/tests/test_htmlrenderer.py | 2 +-
.../tests/test_hyperlinkedserializers.py | 4 +-
rest_framework/tests/test_relations_hyperlink.py | 2 +-
rest_framework/tests/test_renderers.py | 3 +-
rest_framework/tests/test_request.py | 2 +-
rest_framework/tests/test_response.py | 2 +-
rest_framework/tests/test_reverse.py | 2 +-
rest_framework/tests/test_routers.py | 2 +-
rest_framework/tests/test_testing.py | 2 +-
rest_framework/tests/test_urlpatterns.py | 2 +-
rest_framework/urlpatterns.py | 2 +-
rest_framework/urls.py | 2 +-
rest_framework/utils/encoders.py | 3 +-
tox.ini | 22 +-
28 files changed, 57 insertions(+), 507 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index 7ebe715a..456f8e9c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -10,18 +10,15 @@ env:
- DJANGO="https://www.djangoproject.com/download/1.6a1/tarball/"
- DJANGO="django==1.5.1 --use-mirrors"
- DJANGO="django==1.4.5 --use-mirrors"
- - DJANGO="django==1.3.7 --use-mirrors"
install:
- pip install $DJANGO
- - pip install defusedxml==0.3
+ - pip install defusedxml==0.3 --use-mirrors
+ - pip install django-filter==0.6 --use-mirrors
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211 --use-mirrors; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.0 --use-mirrors; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4 --use-mirrors; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-guardian==1.1.1 --use-mirrors; fi"
- - "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4 --use-mirrors; fi"
- - "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install six --use-mirrors; fi"
- - "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.6 --use-mirrors; fi"
- export PYTHONPATH=.
script:
@@ -31,10 +28,5 @@ matrix:
exclude:
- python: "3.2"
env: DJANGO="django==1.4.5 --use-mirrors"
- - python: "3.2"
- env: DJANGO="django==1.3.7 --use-mirrors"
- python: "3.3"
env: DJANGO="django==1.4.5 --use-mirrors"
- - python: "3.3"
- env: DJANGO="django==1.3.7 --use-mirrors"
-
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index cf001a24..db5cce40 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -6,8 +6,8 @@ import base64
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
+from django.middleware.csrf import CsrfViewMiddleware
from rest_framework import exceptions, HTTP_HEADER_ENCODING
-from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
from rest_framework.compat import oauth2_provider, provider_now
from rest_framework.authtoken.models import Token
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 1238f043..f048b10a 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -5,25 +5,19 @@ versions of django/python, and compatibility wrappers around optional packages.
# flake8: noqa
from __future__ import unicode_literals
-
import django
from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
+
# Try to import six from Django, fallback to external `six` package.
try:
from django.utils import six
except ImportError:
import six
-# location of patterns, url, include changes in 1.4 onwards
-try:
- from django.conf.urls import patterns, url, include
-except ImportError:
- from django.conf.urls.defaults import patterns, url, include
-
-# Handle django.utils.encoding rename:
-# smart_unicode -> smart_text
+# Handle django.utils.encoding rename in 1.5 onwards.
+# smart_unicode -> smart_text
# force_unicode -> force_text
try:
from django.utils.encoding import smart_text
@@ -41,13 +35,15 @@ try:
except ImportError:
from django.http import HttpResponse as HttpResponseBase
+
# django-filter is optional
try:
import django_filters
except ImportError:
django_filters = None
-# guardian is optional
+
+# django-guardian is optional
try:
import guardian
except ImportError:
@@ -80,14 +76,6 @@ except ImportError:
Image = None
-def get_concrete_model(model_cls):
- try:
- return model_cls._meta.concrete_model
- except AttributeError:
- # 1.3 does not include concrete model
- return model_cls
-
-
# Django 1.5 add support for custom auth user model
if django.VERSION >= (1, 5):
AUTH_USER_MODEL = settings.AUTH_USER_MODEL
@@ -95,46 +83,13 @@ else:
AUTH_USER_MODEL = 'auth.User'
+# View._allowed_methods only present from 1.5 onwards
if django.VERSION >= (1, 5):
from django.views.generic import View
else:
- from django.views.generic import View as _View
- from django.utils.decorators import classonlymethod
- from django.utils.functional import update_wrapper
-
- class View(_View):
- # 1.3 does not include head method in base View class
- # See: https://code.djangoproject.com/ticket/15668
- @classonlymethod
- def as_view(cls, **initkwargs):
- """
- Main entry point for a request-response process.
- """
- # sanitize keyword arguments
- for key in initkwargs:
- if key in cls.http_method_names:
- raise TypeError("You tried to pass in the %s method name as a "
- "keyword argument to %s(). Don't do that."
- % (key, cls.__name__))
- if not hasattr(cls, key):
- raise TypeError("%s() received an invalid keyword %r" % (
- cls.__name__, key))
-
- def view(request, *args, **kwargs):
- self = cls(**initkwargs)
- if hasattr(self, 'get') and not hasattr(self, 'head'):
- self.head = self.get
- return self.dispatch(request, *args, **kwargs)
-
- # take name and docstring from class
- update_wrapper(view, cls, updated=())
-
- # and possible attributes set by decorators
- # like csrf_exempt from dispatch
- update_wrapper(view, cls.dispatch, assigned=())
- return view
-
- # _allowed_methods only present from 1.5 onwards
+ from django.views.generic import View as DjangoView
+
+ class View(DjangoView):
def _allowed_methods(self):
return [m.upper() for m in self.http_method_names if hasattr(self, m)]
@@ -144,316 +99,16 @@ if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch']
-# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
-if django.VERSION >= (1, 4):
- from django.middleware.csrf import CsrfViewMiddleware
-else:
- import hashlib
- import re
- import random
- import logging
-
- from django.conf import settings
- from django.core.urlresolvers import get_callable
-
- try:
- from logging import NullHandler
- except ImportError:
- class NullHandler(logging.Handler):
- def emit(self, record):
- pass
-
- logger = logging.getLogger('django.request')
-
- if not logger.handlers:
- logger.addHandler(NullHandler())
-
- def same_origin(url1, url2):
- """
- Checks if two URLs are 'same-origin'
- """
- p1, p2 = urlparse.urlparse(url1), urlparse.urlparse(url2)
- return p1[0:2] == p2[0:2]
-
- def constant_time_compare(val1, val2):
- """
- Returns True if the two strings are equal, False otherwise.
-
- The time taken is independent of the number of characters that match.
- """
- if len(val1) != len(val2):
- return False
- result = 0
- for x, y in zip(val1, val2):
- result |= ord(x) ^ ord(y)
- return result == 0
-
- # Use the system (hardware-based) random number generator if it exists.
- if hasattr(random, 'SystemRandom'):
- randrange = random.SystemRandom().randrange
- else:
- randrange = random.randrange
-
- _MAX_CSRF_KEY = 18446744073709551616 # 2 << 63
-
- REASON_NO_REFERER = "Referer checking failed - no Referer."
- REASON_BAD_REFERER = "Referer checking failed - %s does not match %s."
- REASON_NO_CSRF_COOKIE = "CSRF cookie not set."
- REASON_BAD_TOKEN = "CSRF token missing or incorrect."
-
- def _get_failure_view():
- """
- Returns the view to be used for CSRF rejections
- """
- return get_callable(settings.CSRF_FAILURE_VIEW)
-
- def _get_new_csrf_key():
- return hashlib.md5("%s%s" % (randrange(0, _MAX_CSRF_KEY), settings.SECRET_KEY)).hexdigest()
-
- def get_token(request):
- """
- Returns the the CSRF token required for a POST form. The token is an
- alphanumeric value.
-
- A side effect of calling this function is to make the the csrf_protect
- decorator and the CsrfViewMiddleware add a CSRF cookie and a 'Vary: Cookie'
- header to the outgoing response. For this reason, you may need to use this
- function lazily, as is done by the csrf context processor.
- """
- request.META["CSRF_COOKIE_USED"] = True
- return request.META.get("CSRF_COOKIE", None)
-
- def _sanitize_token(token):
- # Allow only alphanum, and ensure we return a 'str' for the sake of the post
- # processing middleware.
- token = re.sub('[^a-zA-Z0-9]', '', str(token.decode('ascii', 'ignore')))
- if token == "":
- # In case the cookie has been truncated to nothing at some point.
- return _get_new_csrf_key()
- else:
- return token
-
- class CsrfViewMiddleware(object):
- """
- Middleware that requires a present and correct csrfmiddlewaretoken
- for POST requests that have a CSRF cookie, and sets an outgoing
- CSRF cookie.
-
- This middleware should be used in conjunction with the csrf_token template
- tag.
- """
- # The _accept and _reject methods currently only exist for the sake of the
- # requires_csrf_token decorator.
- def _accept(self, request):
- # Avoid checking the request twice by adding a custom attribute to
- # request. This will be relevant when both decorator and middleware
- # are used.
- request.csrf_processing_done = True
- return None
-
- def _reject(self, request, reason):
- return _get_failure_view()(request, reason=reason)
-
- def process_view(self, request, callback, callback_args, callback_kwargs):
-
- if getattr(request, 'csrf_processing_done', False):
- return None
-
- try:
- csrf_token = _sanitize_token(request.COOKIES[settings.CSRF_COOKIE_NAME])
- # Use same token next time
- request.META['CSRF_COOKIE'] = csrf_token
- except KeyError:
- csrf_token = None
- # Generate token and store it in the request, so it's available to the view.
- request.META["CSRF_COOKIE"] = _get_new_csrf_key()
-
- # Wait until request.META["CSRF_COOKIE"] has been manipulated before
- # bailing out, so that get_token still works
- if getattr(callback, 'csrf_exempt', False):
- return None
-
- # Assume that anything not defined as 'safe' by RC2616 needs protection.
- if request.method not in ('GET', 'HEAD', 'OPTIONS', 'TRACE'):
- if getattr(request, '_dont_enforce_csrf_checks', False):
- # Mechanism to turn off CSRF checks for test suite. It comes after
- # the creation of CSRF cookies, so that everything else continues to
- # work exactly the same (e.g. cookies are sent etc), but before the
- # any branches that call reject()
- return self._accept(request)
-
- if request.is_secure():
- # Suppose user visits http://example.com/
- # An active network attacker,(man-in-the-middle, MITM) sends a
- # POST form which targets https://example.com/detonate-bomb/ and
- # submits it via javascript.
- #
- # The attacker will need to provide a CSRF cookie and token, but
- # that is no problem for a MITM and the session independent
- # nonce we are using. So the MITM can circumvent the CSRF
- # protection. This is true for any HTTP connection, but anyone
- # using HTTPS expects better! For this reason, for
- # https://example.com/ we need additional protection that treats
- # http://example.com/ as completely untrusted. Under HTTPS,
- # Barth et al. found that the Referer header is missing for
- # same-domain requests in only about 0.2% of cases or less, so
- # we can use strict Referer checking.
- referer = request.META.get('HTTP_REFERER')
- if referer is None:
- logger.warning('Forbidden (%s): %s' % (REASON_NO_REFERER, request.path),
- extra={
- 'status_code': 403,
- 'request': request,
- }
- )
- return self._reject(request, REASON_NO_REFERER)
-
- # Note that request.get_host() includes the port
- good_referer = 'https://%s/' % request.get_host()
- if not same_origin(referer, good_referer):
- reason = REASON_BAD_REFERER % (referer, good_referer)
- logger.warning('Forbidden (%s): %s' % (reason, request.path),
- extra={
- 'status_code': 403,
- 'request': request,
- }
- )
- return self._reject(request, reason)
-
- if csrf_token is None:
- # No CSRF cookie. For POST requests, we insist on a CSRF cookie,
- # and in this way we can avoid all CSRF attacks, including login
- # CSRF.
- logger.warning('Forbidden (%s): %s' % (REASON_NO_CSRF_COOKIE, request.path),
- extra={
- 'status_code': 403,
- 'request': request,
- }
- )
- return self._reject(request, REASON_NO_CSRF_COOKIE)
-
- # check non-cookie token for match
- request_csrf_token = ""
- if request.method == "POST":
- request_csrf_token = request.POST.get('csrfmiddlewaretoken', '')
-
- if request_csrf_token == "":
- # Fall back to X-CSRFToken, to make things easier for AJAX,
- # and possible for PUT/DELETE
- request_csrf_token = request.META.get('HTTP_X_CSRFTOKEN', '')
-
- if not constant_time_compare(request_csrf_token, csrf_token):
- logger.warning('Forbidden (%s): %s' % (REASON_BAD_TOKEN, request.path),
- extra={
- 'status_code': 403,
- 'request': request,
- }
- )
- return self._reject(request, REASON_BAD_TOKEN)
-
- return self._accept(request)
-
-# timezone support is new in Django 1.4
-try:
- from django.utils import timezone
-except ImportError:
- timezone = None
-
-# dateparse is ALSO new in Django 1.4
-try:
- from django.utils.dateparse import parse_date, parse_datetime, parse_time
-except ImportError:
- import datetime
- import re
-
- date_re = re.compile(
- r'(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})$'
- )
-
- datetime_re = re.compile(
- r'(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})'
- r'[T ](?P\d{1,2}):(?P\d{1,2})'
- r'(?::(?P\d{1,2})(?:\.(?P\d{1,6})\d{0,6})?)?'
- r'(?PZ|[+-]\d{1,2}:\d{1,2})?$'
- )
-
- time_re = re.compile(
- r'(?P\d{1,2}):(?P\d{1,2})'
- r'(?::(?P\d{1,2})(?:\.(?P\d{1,6})\d{0,6})?)?'
- )
-
- def parse_date(value):
- match = date_re.match(value)
- if match:
- kw = dict((k, int(v)) for k, v in match.groupdict().iteritems())
- return datetime.date(**kw)
-
- def parse_time(value):
- match = time_re.match(value)
- if match:
- kw = match.groupdict()
- if kw['microsecond']:
- kw['microsecond'] = kw['microsecond'].ljust(6, '0')
- kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
- return datetime.time(**kw)
-
- def parse_datetime(value):
- """Parse datetime, but w/o the timezone awareness in 1.4"""
- match = datetime_re.match(value)
- if match:
- kw = match.groupdict()
- if kw['microsecond']:
- kw['microsecond'] = kw['microsecond'].ljust(6, '0')
- kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
- return datetime.datetime(**kw)
-
-
-# smart_urlquote is new on Django 1.4
-try:
- from django.utils.html import smart_urlquote
-except ImportError:
- import re
- from django.utils.encoding import smart_str
- try:
- from urllib.parse import quote, urlsplit, urlunsplit
- except ImportError: # Python 2
- from urllib import quote
- from urlparse import urlsplit, urlunsplit
-
- unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})')
-
- def smart_urlquote(url):
- "Quotes a URL if it isn't already quoted."
- # Handle IDN before quoting.
- scheme, netloc, path, query, fragment = urlsplit(url)
- try:
- netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE
- except UnicodeError: # invalid domain part
- pass
- else:
- url = urlunsplit((scheme, netloc, path, query, fragment))
-
- # An URL is considered unquoted if it contains no % characters or
- # contains a % not followed by two hexadecimal digits. See #9655.
- if '%' not in url or unquoted_percents_re.search(url):
- # See http://bugs.python.org/issue2637
- url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~')
-
- return force_text(url)
-
-
-# RequestFactory only provide `generic` from 1.5 onwards
-
+# RequestFactory only provides `generic` from 1.5 onwards
from django.test.client import RequestFactory as DjangoRequestFactory
from django.test.client import FakePayload
try:
# In 1.5 the test client uses force_bytes
from django.utils.encoding import force_bytes_or_smart_bytes
except ImportError:
- # In 1.3 and 1.4 the test client just uses smart_str
+ # In 1.4 the test client just uses smart_str
from django.utils.encoding import smart_str as force_bytes_or_smart_bytes
-
class RequestFactory(DjangoRequestFactory):
def generic(self, method, path,
data='', content_type='application/octet-stream', **extra):
@@ -478,6 +133,7 @@ class RequestFactory(DjangoRequestFactory):
r.update(extra)
return self.request(**r)
+
# Markdown is optional
try:
import markdown
@@ -492,7 +148,6 @@ try:
safe_mode = False
md = markdown.Markdown(extensions=extensions, safe_mode=safe_mode)
return md.convert(text)
-
except ImportError:
apply_markdown = None
@@ -510,14 +165,16 @@ try:
except ImportError:
etree = None
-# OAuth is optional
+
+# OAuth2 is optional
try:
# Note: The `oauth2` package actually provides oauth1.0a support. Urg.
import oauth2 as oauth
except ImportError:
oauth = None
-# OAuth is optional
+
+# OAuthProvider is optional
try:
import oauth_provider
from oauth_provider.store import store as oauth_provider_store
@@ -525,6 +182,7 @@ except (ImportError, ImproperlyConfigured):
oauth_provider = None
oauth_provider_store = None
+
# OAuth 2 support is optional
try:
import provider.oauth2 as oauth2_provider
@@ -542,8 +200,6 @@ try:
# Any other supported version does use timezone aware datetimes
from django.utils.timezone import now as provider_now
except ImportError:
- import traceback
- traceback.print_exc()
oauth2_provider = None
oauth2_provider_models = None
oauth2_provider_forms = None
@@ -551,7 +207,8 @@ except ImportError:
oauth2_constants = None
provider_now = None
-# Handle lazy strings
+
+# Handle lazy strings across Py2/Py3
from django.utils.functional import Promise
if six.PY3:
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index b3a9b0df..f340510d 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -18,12 +18,14 @@ from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH
from django.http import QueryDict
from django.forms import widgets
+from django.utils import timezone
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict
+from django.utils.dateparse import parse_date, parse_datetime, parse_time
from rest_framework import ISO_8601
from rest_framework.compat import (
- timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text,
+ BytesIO, six, smart_text,
force_text, is_non_str_iterable
)
from rest_framework.settings import api_settings
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 1c7a8158..790299cc 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -17,9 +17,9 @@ from __future__ import unicode_literals
import itertools
from collections import namedtuple
+from django.conf.urls import patterns, url
from django.core.exceptions import ImproperlyConfigured
from rest_framework import views
-from rest_framework.compat import patterns, url
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.urlpatterns import format_suffix_patterns
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index be721658..12aa73e7 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -93,10 +93,7 @@ INSTALLED_APPS = (
'django.contrib.sessions',
'django.contrib.sites',
'django.contrib.messages',
- # Uncomment the next line to enable the admin:
- # 'django.contrib.admin',
- # Uncomment the next line to enable admin documentation:
- # 'django.contrib.admindocs',
+ 'django.contrib.staticfiles',
'rest_framework',
'rest_framework.authtoken',
'rest_framework.tests',
diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py
index ed5baeae..dff71011 100644
--- a/rest_framework/runtests/urls.py
+++ b/rest_framework/runtests/urls.py
@@ -1,7 +1,7 @@
"""
Blank URLConf just to keep runtests.py happy.
"""
-from rest_framework.compat import patterns
+from django.conf.urls import patterns
urlpatterns = patterns('',
)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index f1775762..9e3881a2 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -18,7 +18,7 @@ from decimal import Decimal
from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
-from rest_framework.compat import get_concrete_model, six
+from rest_framework.compat import six
# Note: We do the following so that users of the framework can use this style:
#
@@ -575,7 +575,7 @@ class ModelSerializer(Serializer):
cls = self.opts.model
assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
- opts = get_concrete_model(cls)._meta
+ opts = cls._meta.concrete_model._meta
ret = SortedDict()
nested = bool(self.opts.depth)
@@ -784,7 +784,7 @@ class ModelSerializer(Serializer):
Return a list of field names to exclude from model validation.
"""
cls = self.opts.model
- opts = get_concrete_model(cls)._meta
+ opts = cls._meta.concrete_model._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
for field_name, field in self.fields.items():
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 2776d550..47377d51 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -1,4 +1,5 @@
{% load url from future %}
+{% load staticfiles %}
{% load rest_framework %}
diff --git a/rest_framework/templates/rest_framework/login_base.html b/rest_framework/templates/rest_framework/login_base.html
index be9a0072..be83c2f5 100644
--- a/rest_framework/templates/rest_framework/login_base.html
+++ b/rest_framework/templates/rest_framework/login_base.html
@@ -1,4 +1,5 @@
{% load url from future %}
+{% load staticfiles %}
{% load rest_framework %}
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index e9c1cdd5..55f36149 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -2,97 +2,14 @@ from __future__ import unicode_literals, absolute_import
from django import template
from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import QueryDict
-from django.utils.html import escape
+from django.utils.html import escape, smart_urlquote
from django.utils.safestring import SafeData, mark_safe
-from rest_framework.compat import urlparse, force_text, six, smart_urlquote
+from rest_framework.compat import urlparse, force_text, six
import re, string
register = template.Library()
-# Note we don't use 'load staticfiles', because we need a 1.3 compatible
-# version, so instead we include the `static` template tag ourselves.
-
-# When 1.3 becomes unsupported by REST framework, we can instead start to
-# use the {% load staticfiles %} tag, remove the following code,
-# and add a dependency that `django.contrib.staticfiles` must be installed.
-
-# Note: We can't put this into the `compat` module because the compat import
-# from rest_framework.compat import ...
-# conflicts with this rest_framework template tag module.
-
-try: # Django 1.5+
- from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode
-
- @register.tag('static')
- def do_static(parser, token):
- return StaticFilesNode.handle_token(parser, token)
-
-except ImportError:
- try: # Django 1.4
- from django.contrib.staticfiles.storage import staticfiles_storage
-
- @register.simple_tag
- def static(path):
- """
- A template tag that returns the URL to a file
- using staticfiles' storage backend
- """
- return staticfiles_storage.url(path)
-
- except ImportError: # Django 1.3
- from urlparse import urljoin
- from django import template
- from django.templatetags.static import PrefixNode
-
- class StaticNode(template.Node):
- def __init__(self, varname=None, path=None):
- if path is None:
- raise template.TemplateSyntaxError(
- "Static template nodes must be given a path to return.")
- self.path = path
- self.varname = varname
-
- def url(self, context):
- path = self.path.resolve(context)
- return self.handle_simple(path)
-
- def render(self, context):
- url = self.url(context)
- if self.varname is None:
- return url
- context[self.varname] = url
- return ''
-
- @classmethod
- def handle_simple(cls, path):
- return urljoin(PrefixNode.handle_simple("STATIC_URL"), path)
-
- @classmethod
- def handle_token(cls, parser, token):
- """
- Class method to parse prefix node and return a Node.
- """
- bits = token.split_contents()
-
- if len(bits) < 2:
- raise template.TemplateSyntaxError(
- "'%s' takes at least one argument (path to file)" % bits[0])
-
- path = parser.compile_filter(bits[1])
-
- if len(bits) >= 2 and bits[-2] == 'as':
- varname = bits[3]
- else:
- varname = None
-
- return cls(varname, path)
-
- @register.tag('static')
- def do_static_13(parser, token):
- return StaticNode.handle_token(parser, token)
-
-
def replace_query_param(url, key, val):
"""
Given a URL and a key/val pair, set or replace an item in the query
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
index a44813b6..e9a817c0 100644
--- a/rest_framework/tests/test_authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -1,4 +1,5 @@
from __future__ import unicode_literals
+from django.conf.urls import patterns, url, include
from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import TestCase
@@ -18,7 +19,6 @@ from rest_framework.authentication import (
OAuth2Authentication
)
from rest_framework.authtoken.models import Token
-from rest_framework.compat import patterns, url, include
from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider
from rest_framework.test import APIRequestFactory, APIClient
diff --git a/rest_framework/tests/test_breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py
index 41ddf2ce..33740cbb 100644
--- a/rest_framework/tests/test_breadcrumbs.py
+++ b/rest_framework/tests/test_breadcrumbs.py
@@ -1,6 +1,6 @@
from __future__ import unicode_literals
+from django.conf.urls import patterns, url
from django.test import TestCase
-from rest_framework.compat import patterns, url
from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework.views import APIView
diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py
index 379db29d..9697c5ee 100644
--- a/rest_framework/tests/test_filters.py
+++ b/rest_framework/tests/test_filters.py
@@ -1,12 +1,13 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
+from django.conf.urls import patterns, url
from django.db import models
from django.core.urlresolvers import reverse
from django.test import TestCase
from django.utils import unittest
from rest_framework import generics, serializers, status, filters
-from rest_framework.compat import django_filters, patterns, url
+from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py
index 8957a43c..6c570dfd 100644
--- a/rest_framework/tests/test_htmlrenderer.py
+++ b/rest_framework/tests/test_htmlrenderer.py
@@ -1,11 +1,11 @@
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
+from django.conf.urls import patterns, url
from django.http import Http404
from django.test import TestCase
from django.template import TemplateDoesNotExist, Template
import django.template.loader
from rest_framework import status
-from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view, renderer_classes
from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py
index 61e613d7..ea7f70f2 100644
--- a/rest_framework/tests/test_hyperlinkedserializers.py
+++ b/rest_framework/tests/test_hyperlinkedserializers.py
@@ -1,8 +1,8 @@
from __future__ import unicode_literals
import json
+from django.conf.urls import patterns, url
from django.test import TestCase
from rest_framework import generics, status, serializers
-from rest_framework.compat import patterns, url
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import (
Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
@@ -24,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer):
class PhotoSerializer(serializers.Serializer):
description = serializers.CharField()
- album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title')
+ album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title')
def restore_object(self, attrs, instance=None):
return Photo(**attrs)
diff --git a/rest_framework/tests/test_relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py
index 3c4d39af..fa6b01ac 100644
--- a/rest_framework/tests/test_relations_hyperlink.py
+++ b/rest_framework/tests/test_relations_hyperlink.py
@@ -1,7 +1,7 @@
from __future__ import unicode_literals
+from django.conf.urls import patterns, url
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.compat import patterns, url
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import (
BlogPost,
diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py
index df6f4aa6..9d1dd77e 100644
--- a/rest_framework/tests/test_renderers.py
+++ b/rest_framework/tests/test_renderers.py
@@ -2,12 +2,13 @@
from __future__ import unicode_literals
from decimal import Decimal
+from django.conf.urls import patterns, url, include
from django.core.cache import cache
from django.test import TestCase
from django.utils import unittest
from django.utils.translation import ugettext_lazy as _
from rest_framework import status, permissions
-from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO
+from rest_framework.compat import yaml, etree, six, StringIO
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py
index 969d8024..d6363425 100644
--- a/rest_framework/tests/test_request.py
+++ b/rest_framework/tests/test_request.py
@@ -2,13 +2,13 @@
Tests for content parsing, and form-overloaded content parsing.
"""
from __future__ import unicode_literals
+from django.conf.urls import patterns
from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase
from rest_framework import status
from rest_framework.authentication import SessionAuthentication
-from rest_framework.compat import patterns
from rest_framework.parsers import (
BaseParser,
FormParser,
diff --git a/rest_framework/tests/test_response.py b/rest_framework/tests/test_response.py
index eea3c641..1c4c551c 100644
--- a/rest_framework/tests/test_response.py
+++ b/rest_framework/tests/test_response.py
@@ -1,7 +1,7 @@
from __future__ import unicode_literals
+from django.conf.urls import patterns, url, include
from django.test import TestCase
from rest_framework.tests.models import BasicModel, BasicModelSerializer
-from rest_framework.compat import patterns, url, include
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework import generics
diff --git a/rest_framework/tests/test_reverse.py b/rest_framework/tests/test_reverse.py
index 690a30b1..320b125d 100644
--- a/rest_framework/tests/test_reverse.py
+++ b/rest_framework/tests/test_reverse.py
@@ -1,6 +1,6 @@
from __future__ import unicode_literals
+from django.conf.urls import patterns, url
from django.test import TestCase
-from rest_framework.compat import patterns, url
from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 3f456fef..1c34648f 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -1,9 +1,9 @@
from __future__ import unicode_literals
+from django.conf.urls import patterns, url, include
from django.db import models
from django.test import TestCase
from django.core.exceptions import ImproperlyConfigured
from rest_framework import serializers, viewsets, permissions
-from rest_framework.compat import include, patterns, url
from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response
from rest_framework.routers import SimpleRouter, DefaultRouter
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 48b8956b..c08dd493 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -1,9 +1,9 @@
# -- coding: utf-8 --
from __future__ import unicode_literals
+from django.conf.urls import patterns, url
from django.contrib.auth.models import User
from django.test import TestCase
-from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
diff --git a/rest_framework/tests/test_urlpatterns.py b/rest_framework/tests/test_urlpatterns.py
index 8132ec4c..e0060e69 100644
--- a/rest_framework/tests/test_urlpatterns.py
+++ b/rest_framework/tests/test_urlpatterns.py
@@ -1,9 +1,9 @@
from __future__ import unicode_literals
from collections import namedtuple
+from django.conf.urls import patterns, url, include
from django.core import urlresolvers
from django.test import TestCase
from rest_framework.test import APIRequestFactory
-from rest_framework.compat import patterns, url, include
from rest_framework.urlpatterns import format_suffix_patterns
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index d9143bb4..a62530c7 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -1,6 +1,6 @@
from __future__ import unicode_literals
+from django.conf.urls import url, include
from django.core.urlresolvers import RegexURLResolver
-from rest_framework.compat import url, include
from rest_framework.settings import api_settings
diff --git a/rest_framework/urls.py b/rest_framework/urls.py
index 9c4719f1..87ec0f0a 100644
--- a/rest_framework/urls.py
+++ b/rest_framework/urls.py
@@ -13,7 +13,7 @@ your authentication settings include `SessionAuthentication`.
)
"""
from __future__ import unicode_literals
-from rest_framework.compat import patterns, url
+from django.conf.urls import patterns, url
template_name = {'template_name': 'rest_framework/login.html'}
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 7efd5417..13a85550 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -2,9 +2,10 @@
Helper classes for parsers.
"""
from __future__ import unicode_literals
+from django.utils import timezone
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
-from rest_framework.compat import timezone, force_text
+from rest_framework.compat import force_text
from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime
import decimal
diff --git a/tox.ini b/tox.ini
index 6e3b8e0a..7bd140e1 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
[tox]
downloadcache = {toxworkdir}/cache/
-envlist = py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,py2.7-django1.4,py2.6-django1.4,py2.7-django1.3,py2.6-django1.3
+envlist = py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,py2.7-django1.4,py2.6-django1.4
[testenv]
commands = {envpython} rest_framework/runtests/runtests.py
@@ -88,23 +88,3 @@ deps = django==1.4.3
oauth2==1.5.211
django-oauth2-provider==0.2.3
django-guardian==1.1.1
-
-[testenv:py2.7-django1.3]
-basepython = python2.7
-deps = django==1.3.5
- django-filter==0.5.4
- defusedxml==0.3
- django-oauth-plus==2.0
- oauth2==1.5.211
- django-oauth2-provider==0.2.3
- django-guardian==1.1.1
-
-[testenv:py2.6-django1.3]
-basepython = python2.6
-deps = django==1.3.5
- django-filter==0.5.4
- defusedxml==0.3
- django-oauth-plus==2.0
- oauth2==1.5.211
- django-oauth2-provider==0.2.3
- django-guardian==1.1.1
--
cgit v1.2.3
From 1bd8fe415296739521fd2e75c0b604cbf3dd3a83 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Wed, 25 Sep 2013 10:36:08 +0100
Subject: Whitespace fix
---
rest_framework/compat.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index f048b10a..efd2581f 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -17,7 +17,7 @@ except ImportError:
import six
# Handle django.utils.encoding rename in 1.5 onwards.
-# smart_unicode -> smart_text
+# smart_unicode -> smart_text
# force_unicode -> force_text
try:
from django.utils.encoding import smart_text
--
cgit v1.2.3
From ab4be47379ba49092063f843fd446919534db776 Mon Sep 17 00:00:00 2001
From: Omer Katz
Date: Thu, 3 Oct 2013 17:34:34 +0200
Subject: Fixed code example.
---
docs/api-guide/routers.md | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/docs/api-guide/routers.md b/docs/api-guide/routers.md
index 730fa876..f20a695b 100644
--- a/docs/api-guide/routers.md
+++ b/docs/api-guide/routers.md
@@ -42,12 +42,15 @@ The example above would generate the following URL patterns:
Any methods on the viewset decorated with `@detail_route` or `@list_route` will also be routed.
For example, given a method like this on the `UserViewSet` class:
- from myapp.permissions import IsAdminOrIsSelf
+ from myapp.permissions import IsAdminOrIsSelf
from rest_framework.decorators import detail_route
-
- @detail_route(methods=['post'], permission_classes=[IsAdminOrIsSelf])
- def set_password(self, request, pk=None):
+
+ class UserViewSet(ModelViewSet):
...
+
+ @detail_route(methods=['post'], permission_classes=[IsAdminOrIsSelf])
+ def set_password(self, request, pk=None):
+ ...
The following URL pattern would additionally be generated:
--
cgit v1.2.3
From 9ab0759e38492d9950d66299ed5c58155d39e696 Mon Sep 17 00:00:00 2001
From: kahnjw
Date: Fri, 6 Dec 2013 14:21:33 -0800
Subject: Add tests to pass for get_ident method in BaseThrottle class.
---
rest_framework/tests/test_throttling.py | 65 +++++++++++++++++++++++++++++++++
1 file changed, 65 insertions(+)
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py
index 41bff692..03127696 100644
--- a/rest_framework/tests/test_throttling.py
+++ b/rest_framework/tests/test_throttling.py
@@ -5,6 +5,7 @@ from __future__ import unicode_literals
from django.test import TestCase
from django.contrib.auth.models import User
from django.core.cache import cache
+from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
@@ -275,3 +276,67 @@ class ScopedRateThrottleTests(TestCase):
self.increment_timer()
response = self.unscoped_view(request)
self.assertEqual(200, response.status_code)
+
+class XffTestingBase(TestCase):
+ def setUp(self):
+
+ class Throttle(ScopedRateThrottle):
+ THROTTLE_RATES = {'test_limit': '1/day'}
+ TIMER_SECONDS = 0
+ timer = lambda self: self.TIMER_SECONDS
+
+ class View(APIView):
+ throttle_classes = (Throttle,)
+ throttle_scope = 'test_limit'
+
+ def get(self, request):
+ return Response('test_limit')
+
+ cache.clear()
+ self.throttle = Throttle()
+ self.view = View.as_view()
+ self.request = APIRequestFactory().get('/some_uri')
+ self.request.META['REMOTE_ADDR'] = '3.3.3.3'
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'
+
+ def config_proxy(self, num_proxies):
+ setattr(api_settings, 'NUM_PROXIES', num_proxies)
+
+
+class IdWithXffBasicTests(XffTestingBase):
+ def test_accepts_request_under_limit(self):
+ self.config_proxy(0)
+ self.assertEqual(200, self.view(self.request).status_code)
+
+ def test_denies_request_over_limit(self):
+ self.config_proxy(0)
+ self.view(self.request)
+ self.assertEqual(429, self.view(self.request).status_code)
+
+
+class XffSpoofingTests(XffTestingBase):
+ def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
+ self.config_proxy(1)
+ self.view(self.request)
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
+ self.assertEqual(429, self.view(self.request).status_code)
+
+ def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
+ self.config_proxy(2)
+ self.view(self.request)
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
+ self.assertEqual(429, self.view(self.request).status_code)
+
+
+class XffUniqueMachinesTest(XffTestingBase):
+ def test_unique_clients_are_counted_independently_with_one_proxy(self):
+ self.config_proxy(1)
+ self.view(self.request)
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
+ self.assertEqual(200, self.view(self.request).status_code)
+
+ def test_unique_clients_are_counted_independently_with_two_proxies(self):
+ self.config_proxy(2)
+ self.view(self.request)
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
+ self.assertEqual(200, self.view(self.request).status_code)
--
cgit v1.2.3
From 89f26c5e040febd27bc9142b0096ca119bb3fa32 Mon Sep 17 00:00:00 2001
From: kahnjw
Date: Fri, 6 Dec 2013 14:21:52 -0800
Subject: Add get_ident method to pass new tests.
---
rest_framework/settings.py | 1 +
rest_framework/throttling.py | 25 ++++++++++++++++++-------
2 files changed, 19 insertions(+), 7 deletions(-)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 8abaf140..383de72e 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -63,6 +63,7 @@ DEFAULTS = {
'user': None,
'anon': None,
},
+ 'NUM_PROXIES': None,
# Pagination
'PAGINATE_BY': None,
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index a946d837..60e46d47 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -18,6 +18,21 @@ class BaseThrottle(object):
"""
raise NotImplementedError('.allow_request() must be overridden')
+ def get_ident(self, request):
+ """
+ Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
+ if present and number of proxies is > 0. If not use all of
+ HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
+ """
+ xff = request.META.get('HTTP_X_FORWARDED_FOR')
+ remote_addr = request.META.get('REMOTE_ADDR')
+ num_proxies = api_settings.NUM_PROXIES
+
+ if xff and num_proxies:
+ return xff.split(',')[-min(num_proxies, len(xff))].strip()
+
+ return xff if xff else remote_addr
+
def wait(self):
"""
Optionally, return a recommended number of seconds to wait before
@@ -152,13 +167,9 @@ class AnonRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
return None # Only throttle unauthenticated requests.
- ident = request.META.get('HTTP_X_FORWARDED_FOR')
- if ident is None:
- ident = request.META.get('REMOTE_ADDR')
-
return self.cache_format % {
'scope': self.scope,
- 'ident': ident
+ 'ident': self.get_ident(request)
}
@@ -176,7 +187,7 @@ class UserRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
ident = request.user.id
else:
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
@@ -224,7 +235,7 @@ class ScopedRateThrottle(SimpleRateThrottle):
if request.user.is_authenticated():
ident = request.user.id
else:
- ident = request.META.get('REMOTE_ADDR', None)
+ ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
--
cgit v1.2.3
From 100a933279e3119e2627d744cd7eb472b542f6fe Mon Sep 17 00:00:00 2001
From: kahnjw
Date: Fri, 6 Dec 2013 14:22:08 -0800
Subject: Add documentation to explain what effect these changes have.
---
docs/api-guide/throttling.md | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md
index cc469217..ee57383c 100644
--- a/docs/api-guide/throttling.md
+++ b/docs/api-guide/throttling.md
@@ -35,11 +35,16 @@ The default throttling policy may be set globally, using the `DEFAULT_THROTTLE_C
'DEFAULT_THROTTLE_RATES': {
'anon': '100/day',
'user': '1000/day'
- }
+ },
+ 'NUM_PROXIES': 2,
}
The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `minute`, `hour` or `day` as the throttle period.
+By default Django REST Framework will try to use the `HTTP_X_FORWARDED_FOR` header to uniquely identify client machines for throttling. If HTTP_X_FORWARDED_FOR is not present `REMOTE_ADDR` header value will be used.
+
+To help Django REST Framework identify unique clients the number of application proxies can be set using `NUM_PROXIES`. This setting will allow the throttle to correctly identify unique requests whenthere are multiple application side proxies in front of the server. `NUM_PROXIES` should be set to an integer. It is important to understand that if you configure `NUM_PROXIES > 0` all clients behind a unique [NAT'd](http://en.wikipedia.org/wiki/Network_address_translation) gateway will be treated as a single client.
+
You can also set the throttling policy on a per-view or per-viewset basis,
using the `APIView` class based views.
--
cgit v1.2.3
From 196c5952e4f610054e832aef36cb2383b8c129c0 Mon Sep 17 00:00:00 2001
From: kahnjw
Date: Fri, 6 Dec 2013 14:24:16 -0800
Subject: Fix typo
---
docs/api-guide/throttling.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md
index ee57383c..69b15a82 100644
--- a/docs/api-guide/throttling.md
+++ b/docs/api-guide/throttling.md
@@ -43,7 +43,7 @@ The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `mi
By default Django REST Framework will try to use the `HTTP_X_FORWARDED_FOR` header to uniquely identify client machines for throttling. If HTTP_X_FORWARDED_FOR is not present `REMOTE_ADDR` header value will be used.
-To help Django REST Framework identify unique clients the number of application proxies can be set using `NUM_PROXIES`. This setting will allow the throttle to correctly identify unique requests whenthere are multiple application side proxies in front of the server. `NUM_PROXIES` should be set to an integer. It is important to understand that if you configure `NUM_PROXIES > 0` all clients behind a unique [NAT'd](http://en.wikipedia.org/wiki/Network_address_translation) gateway will be treated as a single client.
+To help Django REST Framework identify unique clients the number of application proxies can be set using `NUM_PROXIES`. This setting will allow the throttle to correctly identify unique requests when there are multiple application side proxies in front of the server. `NUM_PROXIES` should be set to an integer. It is important to understand that if you configure `NUM_PROXIES > 0` all clients behind a unique [NAT'd](http://en.wikipedia.org/wiki/Network_address_translation) gateway will be treated as a single client.
You can also set the throttling policy on a per-view or per-viewset basis,
using the `APIView` class based views.
--
cgit v1.2.3
From 887da7f6c5a9e7b5007f5e4af32a6b93b18c70ea Mon Sep 17 00:00:00 2001
From: kahnjw
Date: Fri, 6 Dec 2013 14:30:33 -0800
Subject: Add missing tick marks
---
docs/api-guide/throttling.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md
index 69b15a82..34418e84 100644
--- a/docs/api-guide/throttling.md
+++ b/docs/api-guide/throttling.md
@@ -41,7 +41,7 @@ The default throttling policy may be set globally, using the `DEFAULT_THROTTLE_C
The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `minute`, `hour` or `day` as the throttle period.
-By default Django REST Framework will try to use the `HTTP_X_FORWARDED_FOR` header to uniquely identify client machines for throttling. If HTTP_X_FORWARDED_FOR is not present `REMOTE_ADDR` header value will be used.
+By default Django REST Framework will try to use the `HTTP_X_FORWARDED_FOR` header to uniquely identify client machines for throttling. If `HTTP_X_FORWARDED_FOR` is not present `REMOTE_ADDR` header value will be used.
To help Django REST Framework identify unique clients the number of application proxies can be set using `NUM_PROXIES`. This setting will allow the throttle to correctly identify unique requests when there are multiple application side proxies in front of the server. `NUM_PROXIES` should be set to an integer. It is important to understand that if you configure `NUM_PROXIES > 0` all clients behind a unique [NAT'd](http://en.wikipedia.org/wiki/Network_address_translation) gateway will be treated as a single client.
--
cgit v1.2.3
From 23db6c98495d7b3c18a3069c6cb770d5cbc18ee1 Mon Sep 17 00:00:00 2001
From: kahnjw
Date: Fri, 6 Dec 2013 14:52:39 -0800
Subject: PEP8 Compliance
---
rest_framework/tests/test_throttling.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py
index 03127696..8c5eefe9 100644
--- a/rest_framework/tests/test_throttling.py
+++ b/rest_framework/tests/test_throttling.py
@@ -277,6 +277,7 @@ class ScopedRateThrottleTests(TestCase):
response = self.unscoped_view(request)
self.assertEqual(200, response.status_code)
+
class XffTestingBase(TestCase):
def setUp(self):
--
cgit v1.2.3
From 83da4949c099fcf7e7636c98b9052b502e1bf74b Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 13 Dec 2013 00:02:18 +0000
Subject: Allow NUM_PROXIES=0 and include more docs
---
docs/api-guide/settings.md | 6 ++++++
docs/api-guide/throttling.md | 18 ++++++++++++------
rest_framework/throttling.py | 8 ++++++--
3 files changed, 24 insertions(+), 8 deletions(-)
diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md
index 13f96f9a..d8c878ff 100644
--- a/docs/api-guide/settings.md
+++ b/docs/api-guide/settings.md
@@ -359,5 +359,11 @@ The name of a parameter in the URL conf that may be used to provide a format suf
Default: `'format'`
+#### NUM_PROXIES
+
+An integer of 0 or more, that may be used to specify the number of application proxies that the API runs behind. This allows throttling to more accurately identify client IP addresses. If set to `None` then less strict IP matching will be used by the throttle classes.
+
+Default: `None`
+
[cite]: http://www.python.org/dev/peps/pep-0020/
[strftime]: http://docs.python.org/2/library/time.html#time.strftime
diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md
index 34418e84..b2a5bb19 100644
--- a/docs/api-guide/throttling.md
+++ b/docs/api-guide/throttling.md
@@ -35,16 +35,11 @@ The default throttling policy may be set globally, using the `DEFAULT_THROTTLE_C
'DEFAULT_THROTTLE_RATES': {
'anon': '100/day',
'user': '1000/day'
- },
- 'NUM_PROXIES': 2,
+ }
}
The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `minute`, `hour` or `day` as the throttle period.
-By default Django REST Framework will try to use the `HTTP_X_FORWARDED_FOR` header to uniquely identify client machines for throttling. If `HTTP_X_FORWARDED_FOR` is not present `REMOTE_ADDR` header value will be used.
-
-To help Django REST Framework identify unique clients the number of application proxies can be set using `NUM_PROXIES`. This setting will allow the throttle to correctly identify unique requests when there are multiple application side proxies in front of the server. `NUM_PROXIES` should be set to an integer. It is important to understand that if you configure `NUM_PROXIES > 0` all clients behind a unique [NAT'd](http://en.wikipedia.org/wiki/Network_address_translation) gateway will be treated as a single client.
-
You can also set the throttling policy on a per-view or per-viewset basis,
using the `APIView` class based views.
@@ -71,6 +66,16 @@ Or, if you're using the `@api_view` decorator with function based views.
}
return Response(content)
+## How clients are identified
+
+By default the `X-Forwarded-For` HTTP header is used to uniquely identify client machines for throttling. If the `X-Forwarded-For` header is not present, then the value of the `Remote-Addr` header will be used.
+
+If you need to more strictly identify unique clients, you'll need to configure the number of application proxies that the API runs behind by setting the `NUM_PROXIES` setting. This setting should be an integer of 0 or more, and will allow the throttle to identify the client IP as being the last IP address in the `X-Forwarded-For` header, once any application proxy IP addresses have first been excluded.
+
+It is important to understand that if you configure the `NUM_PROXIES` setting, then all clients behind a unique [NAT'd](http://en.wikipedia.org/wiki/Network_address_translation) gateway will be treated as a single client.
+
+Further context on how the `X-Forwarded-For` header works, and identifier a remote client IP can be [found here][identifing-clients].
+
## Setting up the cache
The throttle classes provided by REST framework use Django's cache backend. You should make sure that you've set appropriate [cache settings][cache-setting]. The default value of `LocMemCache` backend should be okay for simple setups. See Django's [cache documentation][cache-docs] for more details.
@@ -183,5 +188,6 @@ The following is an example of a rate throttle, that will randomly throttle 1 in
[cite]: https://dev.twitter.com/docs/error-codes-responses
[permissions]: permissions.md
+[identifing-clients]: http://oxpedia.org/wiki/index.php?title=AppSuite:Grizzly#Multiple_Proxies_in_front_of_the_cluster
[cache-setting]: https://docs.djangoproject.com/en/dev/ref/settings/#caches
[cache-docs]: https://docs.djangoproject.com/en/dev/topics/cache/#setting-up-the-cache
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 60e46d47..c40f3065 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -28,8 +28,12 @@ class BaseThrottle(object):
remote_addr = request.META.get('REMOTE_ADDR')
num_proxies = api_settings.NUM_PROXIES
- if xff and num_proxies:
- return xff.split(',')[-min(num_proxies, len(xff))].strip()
+ if num_proxies is not None:
+ if num_proxies == 0 or xff is None:
+ return remote_addr
+ addrs = xff.split(',')
+ client_addr = addrs[-min(num_proxies, len(xff))]
+ return client_addr.strip()
return xff if xff else remote_addr
--
cgit v1.2.3
From ed931b90ae9e72f963673e6e188b1802a5a65360 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Fri, 13 Dec 2013 00:11:59 +0000
Subject: Further docs tweaks
---
docs/api-guide/throttling.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md
index b2a5bb19..536f0ab7 100644
--- a/docs/api-guide/throttling.md
+++ b/docs/api-guide/throttling.md
@@ -68,13 +68,13 @@ Or, if you're using the `@api_view` decorator with function based views.
## How clients are identified
-By default the `X-Forwarded-For` HTTP header is used to uniquely identify client machines for throttling. If the `X-Forwarded-For` header is not present, then the value of the `Remote-Addr` header will be used.
+The `X-Forwarded-For` and `Remote-Addr` HTTP headers are used to uniquely identify client IP addresses for throttling. If the `X-Forwarded-For` header is present then it will be used, otherwise the value of the `Remote-Addr` header will be used.
-If you need to more strictly identify unique clients, you'll need to configure the number of application proxies that the API runs behind by setting the `NUM_PROXIES` setting. This setting should be an integer of 0 or more, and will allow the throttle to identify the client IP as being the last IP address in the `X-Forwarded-For` header, once any application proxy IP addresses have first been excluded.
+If you need to strictly identify unique client IP addresses, you'll need to first configure the number of application proxies that the API runs behind by setting the `NUM_PROXIES` setting. This setting should be an integer of zero or more. If set to non-zero then the client IP will be identified as being the last IP address in the `X-Forwarded-For` header, once any application proxy IP addresses have first been excluded. If set to zero, then the `Remote-Addr` header will always be used as the identifying IP address.
It is important to understand that if you configure the `NUM_PROXIES` setting, then all clients behind a unique [NAT'd](http://en.wikipedia.org/wiki/Network_address_translation) gateway will be treated as a single client.
-Further context on how the `X-Forwarded-For` header works, and identifier a remote client IP can be [found here][identifing-clients].
+Further context on how the `X-Forwarded-For` header works, and identifing a remote client IP can be [found here][identifing-clients].
## Setting up the cache
--
cgit v1.2.3
From a1d7aa8f712b659f9d8302a2d2a098d2538e6c89 Mon Sep 17 00:00:00 2001
From: Paul Melnikow
Date: Thu, 2 Jan 2014 17:44:47 -0500
Subject: Allow viewset to specify lookup value regex for routing
This patch allows a viewset to define a pattern for its lookup field, which the router will honor. Without this patch, any characters are allowed in the lookup field, and overriding this behavior requires subclassing router and copying and pasting the implementation of get_lookup_regex.
It's possible it would be better to remove this functionality from the routers and simply expose a parameter to get_lookup_regex which allows overriding the lookup_regex. That way the viewset config logic could be in the a subclass, which could invoke the super method directly.
I'm using this now for PostgreSQL UUID fields using https://github.com/dcramer/django-uuidfield . Without this patch, that field passes the lookup string to the database driver, which raises a DataError to complain about the invalid UUID. It's possible the field ought to signal this error in a different way, which could obviate the need to specify a pattern.
---
docs/api-guide/routers.md | 6 ++++++
rest_framework/routers.py | 20 ++++++++++++++------
rest_framework/tests/test_routers.py | 21 +++++++++++++++++++++
3 files changed, 41 insertions(+), 6 deletions(-)
diff --git a/docs/api-guide/routers.md b/docs/api-guide/routers.md
index 846ac9f9..f3beabdd 100644
--- a/docs/api-guide/routers.md
+++ b/docs/api-guide/routers.md
@@ -83,6 +83,12 @@ This behavior can be modified by setting the `trailing_slash` argument to `False
Trailing slashes are conventional in Django, but are not used by default in some other frameworks such as Rails. Which style you choose to use is largely a matter of preference, although some javascript frameworks may expect a particular routing style.
+With `trailing_slash` set to True, the router will match lookup values containing any characters except slashes and dots. When set to False, dots are allowed. To restrict the lookup pattern, set the `lookup_field_regex` attribute on the viewset. For example, you can limit the lookup to valid UUIDs:
+
+ class MyModelViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
+ lookup_field = 'my_model_id'
+ lookup_value_regex = '[0-9a-f]{32}'
+
## DefaultRouter
This router is similar to `SimpleRouter` as above, but additionally includes a default API root view, that returns a response containing hyperlinks to all the list views. It also generates routes for optional `.json` style format suffixes.
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 740d58f0..8766ecb2 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -219,13 +219,21 @@ class SimpleRouter(BaseRouter):
https://github.com/alanjds/drf-nested-routers
"""
- if self.trailing_slash:
- base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/]+)'
- else:
- # Don't consume `.json` style suffixes
- base_regex = '(?P<{lookup_prefix}{lookup_field}>[^/.]+)'
+ base_regex = '(?P<{lookup_prefix}{lookup_field}>{lookup_value})'
lookup_field = getattr(viewset, 'lookup_field', 'pk')
- return base_regex.format(lookup_field=lookup_field, lookup_prefix=lookup_prefix)
+ try:
+ lookup_value = viewset.lookup_value_regex
+ except AttributeError:
+ if self.trailing_slash:
+ lookup_value = '[^/]+'
+ else:
+ # Don't consume `.json` style suffixes
+ lookup_value = '[^/.]+'
+ return base_regex.format(
+ lookup_prefix=lookup_prefix,
+ lookup_field=lookup_field,
+ lookup_value=lookup_value
+ )
def get_urls(self):
"""
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 1c34648f..0f6d62c7 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -121,6 +121,27 @@ class TestCustomLookupFields(TestCase):
)
+class TestLookupValueRegex(TestCase):
+ """
+ Ensure the router honors lookup_value_regex when applied
+ to the viewset.
+ """
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ lookup_field = 'uuid'
+ lookup_value_regex = '[0-9a-f]{32}'
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_limited_by_lookup_value_regex(self):
+ expected = ['^notes/$', '^notes/(?P[0-9a-f]{32})/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
class TestTrailingSlashIncluded(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
--
cgit v1.2.3
From 3cd15fb1713dfc49e1bf1fd48045ca3ae5654e18 Mon Sep 17 00:00:00 2001
From: Paul Melnikow
Date: Sat, 4 Jan 2014 16:57:50 -0500
Subject: Router: Do not automatically adjust lookup_regex when trailing_slash
is True
BREAKING CHANGE
When trailing_slash is set to True, the router no longer will adjust the lookup regex to allow it to include periods. To simulate the old behavior, the programmer should specify `lookup_regex = '[^/]+'` on the viewset.
https://github.com/tomchristie/django-rest-framework/pull/1328#issuecomment-31517099
---
docs/api-guide/routers.md | 2 +-
rest_framework/routers.py | 7 ++-----
rest_framework/tests/test_routers.py | 2 +-
3 files changed, 4 insertions(+), 7 deletions(-)
diff --git a/docs/api-guide/routers.md b/docs/api-guide/routers.md
index f3beabdd..6b4ae6db 100644
--- a/docs/api-guide/routers.md
+++ b/docs/api-guide/routers.md
@@ -83,7 +83,7 @@ This behavior can be modified by setting the `trailing_slash` argument to `False
Trailing slashes are conventional in Django, but are not used by default in some other frameworks such as Rails. Which style you choose to use is largely a matter of preference, although some javascript frameworks may expect a particular routing style.
-With `trailing_slash` set to True, the router will match lookup values containing any characters except slashes and dots. When set to False, dots are allowed. To restrict the lookup pattern, set the `lookup_field_regex` attribute on the viewset. For example, you can limit the lookup to valid UUIDs:
+The router will match lookup values containing any characters except slashes and period characters. For a more restrictive (or lenient) lookup pattern, set the `lookup_field_regex` attribute on the viewset. For example, you can limit the lookup to valid UUIDs:
class MyModelViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
lookup_field = 'my_model_id'
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 8766ecb2..df1233fd 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -224,11 +224,8 @@ class SimpleRouter(BaseRouter):
try:
lookup_value = viewset.lookup_value_regex
except AttributeError:
- if self.trailing_slash:
- lookup_value = '[^/]+'
- else:
- # Don't consume `.json` style suffixes
- lookup_value = '[^/.]+'
+ # Don't consume `.json` style suffixes
+ lookup_value = '[^/.]+'
return base_regex.format(
lookup_prefix=lookup_prefix,
lookup_field=lookup_field,
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
index 0f6d62c7..e41da57f 100644
--- a/rest_framework/tests/test_routers.py
+++ b/rest_framework/tests/test_routers.py
@@ -152,7 +152,7 @@ class TestTrailingSlashIncluded(TestCase):
self.urls = self.router.urls
def test_urls_have_trailing_slash_by_default(self):
- expected = ['^notes/$', '^notes/(?P[^/]+)/$']
+ expected = ['^notes/$', '^notes/(?P[^/.]+)/$']
for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
--
cgit v1.2.3
From 899381575a6038f550a064261ed5c6ba0655211b Mon Sep 17 00:00:00 2001
From: Paul Melnikow
Date: Sat, 4 Jan 2014 17:03:01 -0500
Subject: Fix a typo
---
docs/api-guide/routers.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/api-guide/routers.md b/docs/api-guide/routers.md
index 6b4ae6db..249e99a4 100644
--- a/docs/api-guide/routers.md
+++ b/docs/api-guide/routers.md
@@ -83,7 +83,7 @@ This behavior can be modified by setting the `trailing_slash` argument to `False
Trailing slashes are conventional in Django, but are not used by default in some other frameworks such as Rails. Which style you choose to use is largely a matter of preference, although some javascript frameworks may expect a particular routing style.
-The router will match lookup values containing any characters except slashes and period characters. For a more restrictive (or lenient) lookup pattern, set the `lookup_field_regex` attribute on the viewset. For example, you can limit the lookup to valid UUIDs:
+The router will match lookup values containing any characters except slashes and period characters. For a more restrictive (or lenient) lookup pattern, set the `lookup_value_regex` attribute on the viewset. For example, you can limit the lookup to valid UUIDs:
class MyModelViewSet(mixins.RetrieveModelMixin, viewsets.GenericViewSet):
lookup_field = 'my_model_id'
--
cgit v1.2.3
From 46f5c62530744017f744cdcfec91774a0566c179 Mon Sep 17 00:00:00 2001
From: Yuri Prezument
Date: Sun, 5 Jan 2014 15:16:55 +0200
Subject: Regression test for #1330 (Coerce None to '')
---
rest_framework/tests/test_serializer.py | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 86f365de..8c2c09cf 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -1124,6 +1124,20 @@ class BlankFieldTests(TestCase):
serializer = self.model_serializer_class(data={})
self.assertEqual(serializer.is_valid(), True)
+ def test_create_model_null_field_save(self):
+ """
+ Regression test for #1330.
+
+ https://github.com/tomchristie/django-rest-framework/pull/1330
+ """
+ serializer = self.model_serializer_class(data={'title': None})
+ self.assertEqual(serializer.is_valid(), True)
+
+ try:
+ serializer.save()
+ except Exception:
+ self.fail('Exception raised on save() after validation passes')
+
#test for issue #460
class SerializerPickleTests(TestCase):
--
cgit v1.2.3
From e88e3c6ae163029f0fe564dd214235ab350dbfc9 Mon Sep 17 00:00:00 2001
From: Yuri Prezument
Date: Sun, 5 Jan 2014 15:25:16 +0200
Subject: Possible fix for #1330
Coerce None to '' in CharField.to_native()
---
rest_framework/fields.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 5ee75235..22f0120b 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -452,7 +452,9 @@ class CharField(WritableField):
self.validators.append(validators.MaxLengthValidator(max_length))
def from_native(self, value):
- if isinstance(value, six.string_types) or value is None:
+ if value is None:
+ return ''
+ if isinstance(value, six.string_types):
return value
return smart_text(value)
--
cgit v1.2.3
From 6e622d644c9b55b905e24497f0fb818d557fd970 Mon Sep 17 00:00:00 2001
From: Yuri Prezument
Date: Sun, 5 Jan 2014 15:58:46 +0200
Subject: CharField - add allow_null argument
---
docs/api-guide/fields.md | 7 ++++---
rest_framework/fields.py | 5 +++--
2 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md
index e05c0306..83825350 100644
--- a/docs/api-guide/fields.md
+++ b/docs/api-guide/fields.md
@@ -157,23 +157,24 @@ Corresponds to `django.db.models.fields.BooleanField`.
## CharField
A text representation, optionally validates the text to be shorter than `max_length` and longer than `min_length`.
+If `allow_none` is `False` (default), `None` values will be converted to an empty string.
Corresponds to `django.db.models.fields.CharField`
or `django.db.models.fields.TextField`.
-**Signature:** `CharField(max_length=None, min_length=None)`
+**Signature:** `CharField(max_length=None, min_length=None, allow_none=False)`
## URLField
Corresponds to `django.db.models.fields.URLField`. Uses Django's `django.core.validators.URLValidator` for validation.
-**Signature:** `CharField(max_length=200, min_length=None)`
+**Signature:** `CharField(max_length=200, min_length=None, allow_none=False)`
## SlugField
Corresponds to `django.db.models.fields.SlugField`.
-**Signature:** `CharField(max_length=50, min_length=None)`
+**Signature:** `CharField(max_length=50, min_length=None, allow_none=False)`
## ChoiceField
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 22f0120b..16485b41 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -443,8 +443,9 @@ class CharField(WritableField):
type_label = 'string'
form_field_class = forms.CharField
- def __init__(self, max_length=None, min_length=None, *args, **kwargs):
+ def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length
+ self.allow_none = allow_none
super(CharField, self).__init__(*args, **kwargs)
if min_length is not None:
self.validators.append(validators.MinLengthValidator(min_length))
@@ -452,7 +453,7 @@ class CharField(WritableField):
self.validators.append(validators.MaxLengthValidator(max_length))
def from_native(self, value):
- if value is None:
+ if value is None and not self.allow_none:
return ''
if isinstance(value, six.string_types):
return value
--
cgit v1.2.3
From e1bbe9d514c95aba596cff64292eb0f0bc7d99fa Mon Sep 17 00:00:00 2001
From: Yuri Prezument
Date: Mon, 6 Jan 2014 13:56:57 +0200
Subject: Set `allow_none = True` for CharFields with null=True
---
rest_framework/serializers.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index fa935306..0164965c 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -821,10 +821,15 @@ class ModelSerializer(Serializer):
kwargs.update({attribute: getattr(model_field, attribute)})
try:
- return self.field_mapping[model_field.__class__](**kwargs)
+ field_class = self.field_mapping[model_field.__class__]
except KeyError:
return ModelField(model_field=model_field, **kwargs)
+ if issubclass(field_class, CharField) and model_field.null:
+ kwargs['allow_none'] = True
+
+ return field_class(**kwargs)
+
def get_validation_exclusions(self):
"""
Return a list of field names to exclude from model validation.
--
cgit v1.2.3
From 0fd0454a5c1ddcf8676e23b30dfaee40fa7cb0c8 Mon Sep 17 00:00:00 2001
From: Yuri Prezument
Date: Mon, 6 Jan 2014 14:02:00 +0200
Subject: Test for setting allow_none=True for nullable CharFields
---
rest_framework/tests/test_serializer.py | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 8c2c09cf..6d9b85ee 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -1504,6 +1504,7 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
image_field = models.ImageField(max_length=1024, blank=True)
slug_field = models.SlugField(max_length=1024, blank=True)
url_field = models.URLField(max_length=1024, blank=True)
+ nullable_char_field = models.CharField(max_length=1024, blank=True, null=True)
class AMOAFSerializer(serializers.ModelSerializer):
class Meta:
@@ -1536,6 +1537,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
'url_field': [
('max_length', 1024),
],
+ 'nullable_char_field': [
+ ('max_length', 1024),
+ ('allow_none', True),
+ ],
}
def field_test(self, field):
@@ -1572,6 +1577,9 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
def test_url_field(self):
self.field_test('url_field')
+ def test_nullable_char_field(self):
+ self.field_test('nullable_char_field')
+
class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
--
cgit v1.2.3
From cd9a4194ea4f4dc0e43a34485cd8a27eba44a39a Mon Sep 17 00:00:00 2001
From: Yuri Prezument
Date: Sun, 12 Jan 2014 16:30:26 +0200
Subject: Check the modelfield's class instead
---
rest_framework/serializers.py | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 0164965c..cbf73fc3 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -804,6 +804,10 @@ class ModelSerializer(Serializer):
issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0
+ if model_field.null and \
+ issubclass(model_field.__class__, (models.CharField, models.TextField)):
+ kwargs['allow_none'] = True
+
attribute_dict = {
models.CharField: ['max_length'],
models.CommaSeparatedIntegerField: ['max_length'],
@@ -821,15 +825,10 @@ class ModelSerializer(Serializer):
kwargs.update({attribute: getattr(model_field, attribute)})
try:
- field_class = self.field_mapping[model_field.__class__]
+ return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
return ModelField(model_field=model_field, **kwargs)
- if issubclass(field_class, CharField) and model_field.null:
- kwargs['allow_none'] = True
-
- return field_class(**kwargs)
-
def get_validation_exclusions(self):
"""
Return a list of field names to exclude from model validation.
--
cgit v1.2.3
From a90796c0f0d9db1a7d9bfaca8fbdfed22435c628 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 13 Jan 2014 09:56:57 +0000
Subject: Track changes that need noting in 2.4 announcement
---
docs/topics/2.4-accouncement.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/topics/2.4-accouncement.md b/docs/topics/2.4-accouncement.md
index a5425d54..0cf50ce9 100644
--- a/docs/topics/2.4-accouncement.md
+++ b/docs/topics/2.4-accouncement.md
@@ -1,4 +1,4 @@
* Writable nested serializers.
* List/detail routes.
* 1.3 Support dropped, install six for <=1.4.?.
-* Note title ordering changed
\ No newline at end of file
+* `allow_none` for char fields
--
cgit v1.2.3
From 2911cd64ad67ba193e3d37322ee71692cb482623 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Mon, 13 Jan 2014 15:37:52 +0000
Subject: Minor tweaks to 'lookup_value_regex' work
---
docs/topics/2.4-accouncement.md | 1 +
rest_framework/routers.py | 8 +++-----
2 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/docs/topics/2.4-accouncement.md b/docs/topics/2.4-accouncement.md
index 0cf50ce9..91472b9c 100644
--- a/docs/topics/2.4-accouncement.md
+++ b/docs/topics/2.4-accouncement.md
@@ -2,3 +2,4 @@
* List/detail routes.
* 1.3 Support dropped, install six for <=1.4.?.
* `allow_none` for char fields
+* `trailing_slash = True` --> `[^/]`, `trailing_slash = False` --> `[^/.]`, becomes simply `[^/]` and `lookup_value_regex` is added.
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index df1233fd..406ebcf7 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -220,12 +220,10 @@ class SimpleRouter(BaseRouter):
https://github.com/alanjds/drf-nested-routers
"""
base_regex = '(?P<{lookup_prefix}{lookup_field}>{lookup_value})'
+ # Use `pk` as default field, unset set. Default regex should not
+ # consume `.json` style suffixes and should break at '/' boundaries.
lookup_field = getattr(viewset, 'lookup_field', 'pk')
- try:
- lookup_value = viewset.lookup_value_regex
- except AttributeError:
- # Don't consume `.json` style suffixes
- lookup_value = '[^/.]+'
+ lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
return base_regex.format(
lookup_prefix=lookup_prefix,
lookup_field=lookup_field,
--
cgit v1.2.3
From 971578ca345c3d3bae7fd93b87c41d43483b6f05 Mon Sep 17 00:00:00 2001
From: Andreas Pelme
Date: Sun, 2 Mar 2014 12:40:30 +0100
Subject: Support for running the test suite with py.test
* Get rid of runtests.py
* Moved test code from rest_framework/tests and rest_framework/runtests to tests
* Invoke py.test from setup.py
* Invoke py.test from Travis
* Invoke py.test from tox
* Changed setUpClass to be just plain setUp in test_permissions.py
* Updated contribution guideline to show how to invoke py.test
---
.travis.yml | 3 +-
CONTRIBUTING.md | 2 +-
conftest.py | 85 +
docs/index.md | 16 +-
docs/topics/contributing.md | 2 +-
pytest.ini | 2 +
requirements.txt | 2 +
rest_framework/runtests/__init__.py | 0
rest_framework/runtests/runcoverage.py | 78 -
rest_framework/runtests/runtests.py | 48 -
rest_framework/runtests/settings.py | 169 --
rest_framework/runtests/urls.py | 7 -
rest_framework/tests/__init__.py | 0
rest_framework/tests/accounts/__init__.py | 0
rest_framework/tests/accounts/models.py | 8 -
rest_framework/tests/accounts/serializers.py | 11 -
rest_framework/tests/description.py | 26 -
rest_framework/tests/extras/__init__.py | 0
rest_framework/tests/extras/bad_import.py | 1 -
rest_framework/tests/models.py | 170 --
rest_framework/tests/records/__init__.py | 0
rest_framework/tests/records/models.py | 6 -
rest_framework/tests/serializers.py | 8 -
rest_framework/tests/test_authentication.py | 637 -------
rest_framework/tests/test_breadcrumbs.py | 73 -
rest_framework/tests/test_decorators.py | 157 --
rest_framework/tests/test_description.py | 108 --
rest_framework/tests/test_fields.py | 984 -----------
rest_framework/tests/test_files.py | 95 -
rest_framework/tests/test_filters.py | 615 -------
rest_framework/tests/test_genericrelations.py | 129 --
rest_framework/tests/test_generics.py | 609 -------
rest_framework/tests/test_htmlrenderer.py | 118 --
.../tests/test_hyperlinkedserializers.py | 379 ----
.../tests/test_multitable_inheritance.py | 67 -
rest_framework/tests/test_negotiation.py | 45 -
rest_framework/tests/test_nullable_fields.py | 30 -
rest_framework/tests/test_pagination.py | 517 ------
rest_framework/tests/test_parsers.py | 115 --
rest_framework/tests/test_permissions.py | 300 ----
rest_framework/tests/test_relations.py | 120 --
rest_framework/tests/test_relations_hyperlink.py | 524 ------
rest_framework/tests/test_relations_nested.py | 328 ----
rest_framework/tests/test_relations_pk.py | 551 ------
rest_framework/tests/test_relations_slug.py | 257 ---
rest_framework/tests/test_renderers.py | 651 -------
rest_framework/tests/test_request.py | 347 ----
rest_framework/tests/test_response.py | 278 ---
rest_framework/tests/test_reverse.py | 27 -
rest_framework/tests/test_routers.py | 216 ---
rest_framework/tests/test_serializer.py | 1857 --------------------
.../tests/test_serializer_bulk_update.py | 278 ---
rest_framework/tests/test_serializer_empty.py | 15 -
rest_framework/tests/test_serializer_import.py | 19 -
rest_framework/tests/test_serializer_nested.py | 347 ----
rest_framework/tests/test_serializers.py | 28 -
rest_framework/tests/test_settings.py | 22 -
rest_framework/tests/test_status.py | 33 -
rest_framework/tests/test_templatetags.py | 51 -
rest_framework/tests/test_testing.py | 154 --
rest_framework/tests/test_throttling.py | 277 ---
rest_framework/tests/test_urlpatterns.py | 76 -
rest_framework/tests/test_validation.py | 104 --
rest_framework/tests/test_views.py | 142 --
rest_framework/tests/test_write_only_fields.py | 42 -
rest_framework/tests/tests.py | 16 -
rest_framework/tests/users/__init__.py | 0
rest_framework/tests/users/models.py | 6 -
rest_framework/tests/users/serializers.py | 8 -
rest_framework/tests/views.py | 8 -
setup.py | 17 +-
tests/__init__.py | 0
tests/accounts/__init__.py | 0
tests/accounts/models.py | 8 +
tests/accounts/serializers.py | 11 +
tests/description.py | 26 +
tests/extras/__init__.py | 0
tests/extras/bad_import.py | 1 +
tests/models.py | 170 ++
tests/records/__init__.py | 0
tests/records/models.py | 6 +
tests/serializers.py | 8 +
tests/settings.py | 169 ++
tests/test_authentication.py | 637 +++++++
tests/test_breadcrumbs.py | 73 +
tests/test_decorators.py | 157 ++
tests/test_description.py | 108 ++
tests/test_fields.py | 984 +++++++++++
tests/test_files.py | 95 +
tests/test_filters.py | 615 +++++++
tests/test_genericrelations.py | 129 ++
tests/test_generics.py | 609 +++++++
tests/test_htmlrenderer.py | 118 ++
tests/test_hyperlinkedserializers.py | 379 ++++
tests/test_multitable_inheritance.py | 67 +
tests/test_negotiation.py | 45 +
tests/test_nullable_fields.py | 30 +
tests/test_pagination.py | 517 ++++++
tests/test_parsers.py | 115 ++
tests/test_permissions.py | 291 +++
tests/test_relations.py | 120 ++
tests/test_relations_hyperlink.py | 524 ++++++
tests/test_relations_nested.py | 328 ++++
tests/test_relations_pk.py | 551 ++++++
tests/test_relations_slug.py | 257 +++
tests/test_renderers.py | 651 +++++++
tests/test_request.py | 347 ++++
tests/test_response.py | 278 +++
tests/test_reverse.py | 27 +
tests/test_routers.py | 216 +++
tests/test_serializer.py | 1857 ++++++++++++++++++++
tests/test_serializer_bulk_update.py | 278 +++
tests/test_serializer_empty.py | 15 +
tests/test_serializer_import.py | 19 +
tests/test_serializer_nested.py | 347 ++++
tests/test_serializers.py | 28 +
tests/test_settings.py | 22 +
tests/test_status.py | 33 +
tests/test_templatetags.py | 51 +
tests/test_testing.py | 154 ++
tests/test_throttling.py | 277 +++
tests/test_urlpatterns.py | 76 +
tests/test_validation.py | 104 ++
tests/test_views.py | 142 ++
tests/test_write_only_fields.py | 42 +
tests/urls.py | 6 +
tests/users/__init__.py | 0
tests/users/models.py | 6 +
tests/users/serializers.py | 8 +
tests/views.py | 8 +
tox.ini | 14 +-
131 files changed, 12265 insertions(+), 12310 deletions(-)
create mode 100644 conftest.py
create mode 100644 pytest.ini
delete mode 100644 rest_framework/runtests/__init__.py
delete mode 100755 rest_framework/runtests/runcoverage.py
delete mode 100755 rest_framework/runtests/runtests.py
delete mode 100644 rest_framework/runtests/settings.py
delete mode 100644 rest_framework/runtests/urls.py
delete mode 100644 rest_framework/tests/__init__.py
delete mode 100644 rest_framework/tests/accounts/__init__.py
delete mode 100644 rest_framework/tests/accounts/models.py
delete mode 100644 rest_framework/tests/accounts/serializers.py
delete mode 100644 rest_framework/tests/description.py
delete mode 100644 rest_framework/tests/extras/__init__.py
delete mode 100644 rest_framework/tests/extras/bad_import.py
delete mode 100644 rest_framework/tests/models.py
delete mode 100644 rest_framework/tests/records/__init__.py
delete mode 100644 rest_framework/tests/records/models.py
delete mode 100644 rest_framework/tests/serializers.py
delete mode 100644 rest_framework/tests/test_authentication.py
delete mode 100644 rest_framework/tests/test_breadcrumbs.py
delete mode 100644 rest_framework/tests/test_decorators.py
delete mode 100644 rest_framework/tests/test_description.py
delete mode 100644 rest_framework/tests/test_fields.py
delete mode 100644 rest_framework/tests/test_files.py
delete mode 100644 rest_framework/tests/test_filters.py
delete mode 100644 rest_framework/tests/test_genericrelations.py
delete mode 100644 rest_framework/tests/test_generics.py
delete mode 100644 rest_framework/tests/test_htmlrenderer.py
delete mode 100644 rest_framework/tests/test_hyperlinkedserializers.py
delete mode 100644 rest_framework/tests/test_multitable_inheritance.py
delete mode 100644 rest_framework/tests/test_negotiation.py
delete mode 100644 rest_framework/tests/test_nullable_fields.py
delete mode 100644 rest_framework/tests/test_pagination.py
delete mode 100644 rest_framework/tests/test_parsers.py
delete mode 100644 rest_framework/tests/test_permissions.py
delete mode 100644 rest_framework/tests/test_relations.py
delete mode 100644 rest_framework/tests/test_relations_hyperlink.py
delete mode 100644 rest_framework/tests/test_relations_nested.py
delete mode 100644 rest_framework/tests/test_relations_pk.py
delete mode 100644 rest_framework/tests/test_relations_slug.py
delete mode 100644 rest_framework/tests/test_renderers.py
delete mode 100644 rest_framework/tests/test_request.py
delete mode 100644 rest_framework/tests/test_response.py
delete mode 100644 rest_framework/tests/test_reverse.py
delete mode 100644 rest_framework/tests/test_routers.py
delete mode 100644 rest_framework/tests/test_serializer.py
delete mode 100644 rest_framework/tests/test_serializer_bulk_update.py
delete mode 100644 rest_framework/tests/test_serializer_empty.py
delete mode 100644 rest_framework/tests/test_serializer_import.py
delete mode 100644 rest_framework/tests/test_serializer_nested.py
delete mode 100644 rest_framework/tests/test_serializers.py
delete mode 100644 rest_framework/tests/test_settings.py
delete mode 100644 rest_framework/tests/test_status.py
delete mode 100644 rest_framework/tests/test_templatetags.py
delete mode 100644 rest_framework/tests/test_testing.py
delete mode 100644 rest_framework/tests/test_throttling.py
delete mode 100644 rest_framework/tests/test_urlpatterns.py
delete mode 100644 rest_framework/tests/test_validation.py
delete mode 100644 rest_framework/tests/test_views.py
delete mode 100644 rest_framework/tests/test_write_only_fields.py
delete mode 100644 rest_framework/tests/tests.py
delete mode 100644 rest_framework/tests/users/__init__.py
delete mode 100644 rest_framework/tests/users/models.py
delete mode 100644 rest_framework/tests/users/serializers.py
delete mode 100644 rest_framework/tests/views.py
create mode 100644 tests/__init__.py
create mode 100644 tests/accounts/__init__.py
create mode 100644 tests/accounts/models.py
create mode 100644 tests/accounts/serializers.py
create mode 100644 tests/description.py
create mode 100644 tests/extras/__init__.py
create mode 100644 tests/extras/bad_import.py
create mode 100644 tests/models.py
create mode 100644 tests/records/__init__.py
create mode 100644 tests/records/models.py
create mode 100644 tests/serializers.py
create mode 100644 tests/settings.py
create mode 100644 tests/test_authentication.py
create mode 100644 tests/test_breadcrumbs.py
create mode 100644 tests/test_decorators.py
create mode 100644 tests/test_description.py
create mode 100644 tests/test_fields.py
create mode 100644 tests/test_files.py
create mode 100644 tests/test_filters.py
create mode 100644 tests/test_genericrelations.py
create mode 100644 tests/test_generics.py
create mode 100644 tests/test_htmlrenderer.py
create mode 100644 tests/test_hyperlinkedserializers.py
create mode 100644 tests/test_multitable_inheritance.py
create mode 100644 tests/test_negotiation.py
create mode 100644 tests/test_nullable_fields.py
create mode 100644 tests/test_pagination.py
create mode 100644 tests/test_parsers.py
create mode 100644 tests/test_permissions.py
create mode 100644 tests/test_relations.py
create mode 100644 tests/test_relations_hyperlink.py
create mode 100644 tests/test_relations_nested.py
create mode 100644 tests/test_relations_pk.py
create mode 100644 tests/test_relations_slug.py
create mode 100644 tests/test_renderers.py
create mode 100644 tests/test_request.py
create mode 100644 tests/test_response.py
create mode 100644 tests/test_reverse.py
create mode 100644 tests/test_routers.py
create mode 100644 tests/test_serializer.py
create mode 100644 tests/test_serializer_bulk_update.py
create mode 100644 tests/test_serializer_empty.py
create mode 100644 tests/test_serializer_import.py
create mode 100644 tests/test_serializer_nested.py
create mode 100644 tests/test_serializers.py
create mode 100644 tests/test_settings.py
create mode 100644 tests/test_status.py
create mode 100644 tests/test_templatetags.py
create mode 100644 tests/test_testing.py
create mode 100644 tests/test_throttling.py
create mode 100644 tests/test_urlpatterns.py
create mode 100644 tests/test_validation.py
create mode 100644 tests/test_views.py
create mode 100644 tests/test_write_only_fields.py
create mode 100644 tests/urls.py
create mode 100644 tests/users/__init__.py
create mode 100644 tests/users/models.py
create mode 100644 tests/users/serializers.py
create mode 100644 tests/views.py
diff --git a/.travis.yml b/.travis.yml
index 2e6ed46a..061d4c73 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -15,6 +15,7 @@ env:
install:
- pip install $DJANGO
- pip install defusedxml==0.3
+ - pip install pytest-django==2.6
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.1; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi"
@@ -24,7 +25,7 @@ install:
- export PYTHONPATH=.
script:
- - python rest_framework/runtests/runtests.py
+ - py.test
matrix:
exclude:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a7aa6fc4..ff6018b8 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -65,7 +65,7 @@ To run the tests, clone the repository, and then:
pip install -r optionals.txt
# Run the tests
- rest_framework/runtests/runtests.py
+ py.test
You can also use the excellent [`tox`][tox] testing tool to run the tests against all supported versions of Python and Django. Install `tox` globally, and then simply run:
diff --git a/conftest.py b/conftest.py
new file mode 100644
index 00000000..7cfc77f2
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,85 @@
+def pytest_configure():
+ from django.conf import settings
+
+ settings.configure(
+ DEBUG_PROPAGATE_EXCEPTIONS=True,
+ DATABASES={'default': {'ENGINE': 'django.db.backends.sqlite3',
+ 'NAME': ':memory:'}},
+ SECRET_KEY='not very secret in tests',
+ USE_I18N=True,
+ USE_L10N=True,
+ STATIC_URL='/static/',
+ ROOT_URLCONF='tests.urls',
+ TEMPLATE_LOADERS=(
+ 'django.template.loaders.filesystem.Loader',
+ 'django.template.loaders.app_directories.Loader',
+ ),
+ MIDDLEWARE_CLASSES=(
+ 'django.middleware.common.CommonMiddleware',
+ 'django.contrib.sessions.middleware.SessionMiddleware',
+ 'django.middleware.csrf.CsrfViewMiddleware',
+ 'django.contrib.auth.middleware.AuthenticationMiddleware',
+ 'django.contrib.messages.middleware.MessageMiddleware',
+ ),
+ INSTALLED_APPS=(
+ 'django.contrib.auth',
+ 'django.contrib.contenttypes',
+ 'django.contrib.sessions',
+ 'django.contrib.sites',
+ 'django.contrib.messages',
+
+ 'rest_framework',
+ 'rest_framework.authtoken',
+ 'tests',
+ 'tests.accounts',
+ 'tests.records',
+ 'tests.users',
+ ),
+ PASSWORD_HASHERS=(
+ 'django.contrib.auth.hashers.SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.BCryptPasswordHasher',
+ 'django.contrib.auth.hashers.MD5PasswordHasher',
+ 'django.contrib.auth.hashers.CryptPasswordHasher',
+ ),
+ )
+
+ try:
+ import oauth_provider
+ import oauth2
+ except ImportError:
+ pass
+ else:
+ settings.INSTALLED_APPS += (
+ 'oauth_provider',
+ )
+
+ try:
+ import provider
+ except ImportError:
+ pass
+ else:
+ settings.INSTALLED_APPS += (
+ 'provider',
+ 'provider.oauth2',
+ )
+
+ # guardian is optional
+ try:
+ import guardian
+ except ImportError:
+ pass
+ else:
+ settings.ANONYMOUS_USER_ID = -1
+ settings.AUTHENTICATION_BACKENDS = (
+ 'django.contrib.auth.backends.ModelBackend', # default
+ 'guardian.backends.ObjectPermissionBackend',
+ )
+ settings.INSTALLED_APPS += (
+ 'guardian',
+ )
+
+ # Force Django to load all models
+ from django.db.models import get_models
+ get_models()
diff --git a/docs/index.md b/docs/index.md
index 2a4ad885..9ad647ac 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -206,19 +206,9 @@ General guides to using REST framework.
## Development
-If you want to work on REST framework itself, clone the repository, then...
-
-Build the docs:
-
- ./mkdocs.py
-
-Run the tests:
-
- ./rest_framework/runtests/runtests.py
-
-To run the tests against all supported configurations, first install [the tox testing tool][tox] globally, using `pip install tox`, then simply run `tox`:
-
- tox
+See the [Contribution guidelines][contributing] for information on how to clone
+the repository, run the test suite and contribute changes back to REST
+Framework.
## Support
diff --git a/docs/topics/contributing.md b/docs/topics/contributing.md
index 5a5d1a80..09cc00b3 100644
--- a/docs/topics/contributing.md
+++ b/docs/topics/contributing.md
@@ -65,7 +65,7 @@ To run the tests, clone the repository, and then:
pip install -r optionals.txt
# Run the tests
- rest_framework/runtests/runtests.py
+ py.test
You can also use the excellent `[tox][tox]` testing tool to run the tests against all supported versions of Python and Django. Install `tox` globally, and then simply run:
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..bbd083ac
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts = --tb=short
diff --git a/requirements.txt b/requirements.txt
index 730c1d07..360acb14 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,3 @@
+-e .
Django>=1.3
+pytest-django==2.6
diff --git a/rest_framework/runtests/__init__.py b/rest_framework/runtests/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py
deleted file mode 100755
index ce11b213..00000000
--- a/rest_framework/runtests/runcoverage.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/usr/bin/env python
-"""
-Useful tool to run the test suite for rest_framework and generate a coverage report.
-"""
-
-# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/
-# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/
-# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
-import os
-import sys
-
-# fix sys path so we don't need to setup PYTHONPATH
-sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
-os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
-
-from coverage import coverage
-
-
-def main():
- """Run the tests for rest_framework and generate a coverage report."""
-
- cov = coverage()
- cov.erase()
- cov.start()
-
- from django.conf import settings
- from django.test.utils import get_runner
- TestRunner = get_runner(settings)
-
- if hasattr(TestRunner, 'func_name'):
- # Pre 1.2 test runners were just functions,
- # and did not support the 'failfast' option.
- import warnings
- warnings.warn(
- 'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
- DeprecationWarning
- )
- failures = TestRunner(['tests'])
- else:
- test_runner = TestRunner()
- failures = test_runner.run_tests(['tests'])
- cov.stop()
-
- # Discover the list of all modules that we should test coverage for
- import rest_framework
-
- project_dir = os.path.dirname(rest_framework.__file__)
- cov_files = []
-
- for (path, dirs, files) in os.walk(project_dir):
- # Drop tests and runtests directories from the test coverage report
- if os.path.basename(path) in ['tests', 'runtests', 'migrations']:
- continue
-
- # Drop the compat and six modules from coverage, since we're not interested in the coverage
- # of modules which are specifically for resolving environment dependant imports.
- # (Because we'll end up getting different coverage reports for it for each environment)
- if 'compat.py' in files:
- files.remove('compat.py')
-
- if 'six.py' in files:
- files.remove('six.py')
-
- # Same applies to template tags module.
- # This module has to include branching on Django versions,
- # so it's never possible for it to have full coverage.
- if 'rest_framework.py' in files:
- files.remove('rest_framework.py')
-
- cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')])
-
- cov.report(cov_files)
- if '--html' in sys.argv:
- cov.html_report(cov_files, directory='coverage')
- sys.exit(failures)
-
-if __name__ == '__main__':
- main()
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
deleted file mode 100755
index da36d23f..00000000
--- a/rest_framework/runtests/runtests.py
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/usr/bin/env python
-
-# http://ericholscher.com/blog/2009/jun/29/enable-setuppy-test-your-django-apps/
-# http://www.travisswicegood.com/2010/01/17/django-virtualenv-pip-and-fabric/
-# http://code.djangoproject.com/svn/django/trunk/tests/runtests.py
-import os
-import sys
-
-# fix sys path so we don't need to setup PYTHONPATH
-sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
-os.environ['DJANGO_SETTINGS_MODULE'] = 'rest_framework.runtests.settings'
-
-import django
-from django.conf import settings
-from django.test.utils import get_runner
-
-
-def usage():
- return """
- Usage: python runtests.py [UnitTestClass].[method]
-
- You can pass the Class name of the `UnitTestClass` you want to test.
-
- Append a method name if you only want to test a specific method of that class.
- """
-
-
-def main():
- TestRunner = get_runner(settings)
-
- test_runner = TestRunner()
- if len(sys.argv) == 2:
- test_case = '.' + sys.argv[1]
- elif len(sys.argv) == 1:
- test_case = ''
- else:
- print(usage())
- sys.exit(1)
- test_module_name = 'rest_framework.tests'
- if django.VERSION[0] == 1 and django.VERSION[1] < 6:
- test_module_name = 'tests'
-
- failures = test_runner.run_tests([test_module_name + test_case])
-
- sys.exit(failures)
-
-if __name__ == '__main__':
- main()
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
deleted file mode 100644
index 3fc0eb2f..00000000
--- a/rest_framework/runtests/settings.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Django settings for testproject project.
-
-DEBUG = True
-TEMPLATE_DEBUG = DEBUG
-DEBUG_PROPAGATE_EXCEPTIONS = True
-
-ALLOWED_HOSTS = ['*']
-
-ADMINS = (
- # ('Your Name', 'your_email@domain.com'),
-)
-
-MANAGERS = ADMINS
-
-DATABASES = {
- 'default': {
- 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'.
- 'NAME': 'sqlite.db', # Or path to database file if using sqlite3.
- 'USER': '', # Not used with sqlite3.
- 'PASSWORD': '', # Not used with sqlite3.
- 'HOST': '', # Set to empty string for localhost. Not used with sqlite3.
- 'PORT': '', # Set to empty string for default. Not used with sqlite3.
- }
-}
-
-CACHES = {
- 'default': {
- 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
- }
-}
-
-# Local time zone for this installation. Choices can be found here:
-# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name
-# although not all choices may be available on all operating systems.
-# On Unix systems, a value of None will cause Django to use the same
-# timezone as the operating system.
-# If running in a Windows environment this must be set to the same as your
-# system time zone.
-TIME_ZONE = 'Europe/London'
-
-# Language code for this installation. All choices can be found here:
-# http://www.i18nguy.com/unicode/language-identifiers.html
-LANGUAGE_CODE = 'en-uk'
-
-SITE_ID = 1
-
-# If you set this to False, Django will make some optimizations so as not
-# to load the internationalization machinery.
-USE_I18N = True
-
-# If you set this to False, Django will not format dates, numbers and
-# calendars according to the current locale
-USE_L10N = True
-
-# Absolute filesystem path to the directory that will hold user-uploaded files.
-# Example: "/home/media/media.lawrence.com/"
-MEDIA_ROOT = ''
-
-# URL that handles the media served from MEDIA_ROOT. Make sure to use a
-# trailing slash if there is a path component (optional in other cases).
-# Examples: "http://media.lawrence.com", "http://example.com/media/"
-MEDIA_URL = ''
-
-# Make this unique, and don't share it with anybody.
-SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'
-
-# List of callables that know how to import templates from various sources.
-TEMPLATE_LOADERS = (
- 'django.template.loaders.filesystem.Loader',
- 'django.template.loaders.app_directories.Loader',
-# 'django.template.loaders.eggs.Loader',
-)
-
-MIDDLEWARE_CLASSES = (
- 'django.middleware.common.CommonMiddleware',
- 'django.contrib.sessions.middleware.SessionMiddleware',
- 'django.middleware.csrf.CsrfViewMiddleware',
- 'django.contrib.auth.middleware.AuthenticationMiddleware',
- 'django.contrib.messages.middleware.MessageMiddleware',
-)
-
-ROOT_URLCONF = 'urls'
-
-TEMPLATE_DIRS = (
- # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates".
- # Always use forward slashes, even on Windows.
- # Don't forget to use absolute paths, not relative paths.
-)
-
-INSTALLED_APPS = (
- 'django.contrib.auth',
- 'django.contrib.contenttypes',
- 'django.contrib.sessions',
- 'django.contrib.sites',
- 'django.contrib.messages',
- # Uncomment the next line to enable the admin:
- # 'django.contrib.admin',
- # Uncomment the next line to enable admin documentation:
- # 'django.contrib.admindocs',
- 'rest_framework',
- 'rest_framework.authtoken',
- 'rest_framework.tests',
- 'rest_framework.tests.accounts',
- 'rest_framework.tests.records',
- 'rest_framework.tests.users',
-)
-
-# OAuth is optional and won't work if there is no oauth_provider & oauth2
-try:
- import oauth_provider
- import oauth2
-except ImportError:
- pass
-else:
- INSTALLED_APPS += (
- 'oauth_provider',
- )
-
-try:
- import provider
-except ImportError:
- pass
-else:
- INSTALLED_APPS += (
- 'provider',
- 'provider.oauth2',
- )
-
-# guardian is optional
-try:
- import guardian
-except ImportError:
- pass
-else:
- ANONYMOUS_USER_ID = -1
- AUTHENTICATION_BACKENDS = (
- 'django.contrib.auth.backends.ModelBackend', # default
- 'guardian.backends.ObjectPermissionBackend',
- )
- INSTALLED_APPS += (
- 'guardian',
- )
-
-STATIC_URL = '/static/'
-
-PASSWORD_HASHERS = (
- 'django.contrib.auth.hashers.SHA1PasswordHasher',
- 'django.contrib.auth.hashers.PBKDF2PasswordHasher',
- 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
- 'django.contrib.auth.hashers.BCryptPasswordHasher',
- 'django.contrib.auth.hashers.MD5PasswordHasher',
- 'django.contrib.auth.hashers.CryptPasswordHasher',
-)
-
-AUTH_USER_MODEL = 'auth.User'
-
-import django
-
-if django.VERSION < (1, 3):
- INSTALLED_APPS += ('staticfiles',)
-
-
-# If we're running on the Jenkins server we want to archive the coverage reports as XML.
-import os
-if os.environ.get('HUDSON_URL', None):
- TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner'
- TEST_OUTPUT_VERBOSE = True
- TEST_OUTPUT_DESCRIPTIONS = True
- TEST_OUTPUT_DIR = 'xmlrunner'
diff --git a/rest_framework/runtests/urls.py b/rest_framework/runtests/urls.py
deleted file mode 100644
index ed5baeae..00000000
--- a/rest_framework/runtests/urls.py
+++ /dev/null
@@ -1,7 +0,0 @@
-"""
-Blank URLConf just to keep runtests.py happy.
-"""
-from rest_framework.compat import patterns
-
-urlpatterns = patterns('',
-)
diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/rest_framework/tests/accounts/__init__.py b/rest_framework/tests/accounts/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/rest_framework/tests/accounts/models.py b/rest_framework/tests/accounts/models.py
deleted file mode 100644
index 525e601b..00000000
--- a/rest_framework/tests/accounts/models.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from django.db import models
-
-from rest_framework.tests.users.models import User
-
-
-class Account(models.Model):
- owner = models.ForeignKey(User, related_name='accounts_owned')
- admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered')
diff --git a/rest_framework/tests/accounts/serializers.py b/rest_framework/tests/accounts/serializers.py
deleted file mode 100644
index a27b9ca6..00000000
--- a/rest_framework/tests/accounts/serializers.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from rest_framework import serializers
-
-from rest_framework.tests.accounts.models import Account
-from rest_framework.tests.users.serializers import UserSerializer
-
-
-class AccountSerializer(serializers.ModelSerializer):
- admins = UserSerializer(many=True)
-
- class Meta:
- model = Account
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
deleted file mode 100644
index b46d7f54..00000000
--- a/rest_framework/tests/description.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# -- coding: utf-8 --
-
-# Apparently there is a python 2.6 issue where docstrings of imported view classes
-# do not retain their encoding information even if a module has a proper
-# encoding declaration at the top of its source file. Therefore for tests
-# to catch unicode related errors, a mock view has to be declared in a separate
-# module.
-
-from rest_framework.views import APIView
-
-
-# test strings snatched from http://www.columbia.edu/~fdc/utf8/,
-# http://winrus.com/utf8-jap.htm and memory
-UTF8_TEST_DOCSTRING = (
- 'zażółć gęślą jaźń'
- 'Sîne klâwen durh die wolken sint geslagen'
- 'Τη γλώσσα μου έδωσαν ελληνική'
- 'யாமறிந்த மொழிகளிலே தமிழ்மொழி'
- 'На берегу пустынных волн'
- 'てすと'
- 'アイウエオカキクケコサシスセソタチツテ'
-)
-
-
-class ViewWithNonASCIICharactersInDocstring(APIView):
- __doc__ = UTF8_TEST_DOCSTRING
diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py
deleted file mode 100644
index 68263d94..00000000
--- a/rest_framework/tests/extras/bad_import.py
+++ /dev/null
@@ -1 +0,0 @@
-raise ValueError
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
deleted file mode 100644
index 32a726c0..00000000
--- a/rest_framework/tests/models.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.utils.translation import ugettext_lazy as _
-from rest_framework import serializers
-
-
-def foobar():
- return 'foobar'
-
-
-class CustomField(models.CharField):
-
- def __init__(self, *args, **kwargs):
- kwargs['max_length'] = 12
- super(CustomField, self).__init__(*args, **kwargs)
-
-
-class RESTFrameworkModel(models.Model):
- """
- Base for test models that sets app_label, so they play nicely.
- """
- class Meta:
- app_label = 'tests'
- abstract = True
-
-
-class HasPositiveIntegerAsChoice(RESTFrameworkModel):
- some_choices = ((1, 'A'), (2, 'B'), (3, 'C'))
- some_integer = models.PositiveIntegerField(choices=some_choices)
-
-
-class Anchor(RESTFrameworkModel):
- text = models.CharField(max_length=100, default='anchor')
-
-
-class BasicModel(RESTFrameworkModel):
- text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description."))
-
-
-class SlugBasedModel(RESTFrameworkModel):
- text = models.CharField(max_length=100)
- slug = models.SlugField(max_length=32)
-
-
-class DefaultValueModel(RESTFrameworkModel):
- text = models.CharField(default='foobar', max_length=100)
- extra = models.CharField(blank=True, null=True, max_length=100)
-
-
-class CallableDefaultValueModel(RESTFrameworkModel):
- text = models.CharField(default=foobar, max_length=100)
-
-
-class ManyToManyModel(RESTFrameworkModel):
- rel = models.ManyToManyField(Anchor, help_text='Some help text.')
-
-
-class ReadOnlyManyToManyModel(RESTFrameworkModel):
- text = models.CharField(max_length=100, default='anchor')
- rel = models.ManyToManyField(Anchor)
-
-
-# Model for regression test for #285
-
-class Comment(RESTFrameworkModel):
- email = models.EmailField()
- content = models.CharField(max_length=200)
- created = models.DateTimeField(auto_now_add=True)
-
-
-class ActionItem(RESTFrameworkModel):
- title = models.CharField(max_length=200)
- started = models.NullBooleanField(default=False)
- done = models.BooleanField(default=False)
- info = CustomField(default='---', max_length=12)
-
-
-# Models for reverse relations
-class Person(RESTFrameworkModel):
- name = models.CharField(max_length=10)
- age = models.IntegerField(null=True, blank=True)
-
- @property
- def info(self):
- return {
- 'name': self.name,
- 'age': self.age,
- }
-
-
-class BlogPost(RESTFrameworkModel):
- title = models.CharField(max_length=100)
- writer = models.ForeignKey(Person, null=True, blank=True)
-
- def get_first_comment(self):
- return self.blogpostcomment_set.all()[0]
-
-
-class BlogPostComment(RESTFrameworkModel):
- text = models.TextField()
- blog_post = models.ForeignKey(BlogPost)
-
-
-class Album(RESTFrameworkModel):
- title = models.CharField(max_length=100, unique=True)
-
-
-class Photo(RESTFrameworkModel):
- description = models.TextField()
- album = models.ForeignKey(Album)
-
-
-# Model for issue #324
-class BlankFieldModel(RESTFrameworkModel):
- title = models.CharField(max_length=100, blank=True, null=False)
-
-
-# Model for issue #380
-class OptionalRelationModel(RESTFrameworkModel):
- other = models.ForeignKey('OptionalRelationModel', blank=True, null=True)
-
-
-# Model for RegexField
-class Book(RESTFrameworkModel):
- isbn = models.CharField(max_length=13)
-
-
-# Models for relations tests
-# ManyToMany
-class ManyToManyTarget(RESTFrameworkModel):
- name = models.CharField(max_length=100)
-
-
-class ManyToManySource(RESTFrameworkModel):
- name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
-
-
-# ForeignKey
-class ForeignKeyTarget(RESTFrameworkModel):
- name = models.CharField(max_length=100)
-
-
-class ForeignKeySource(RESTFrameworkModel):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
-
-
-# Nullable ForeignKey
-class NullableForeignKeySource(RESTFrameworkModel):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
- related_name='nullable_sources')
-
-
-# OneToOne
-class OneToOneTarget(RESTFrameworkModel):
- name = models.CharField(max_length=100)
-
-
-class NullableOneToOneSource(RESTFrameworkModel):
- name = models.CharField(max_length=100)
- target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
- related_name='nullable_source')
-
-
-# Serializer used to test BasicModel
-class BasicModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = BasicModel
diff --git a/rest_framework/tests/records/__init__.py b/rest_framework/tests/records/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/rest_framework/tests/records/models.py b/rest_framework/tests/records/models.py
deleted file mode 100644
index 76954807..00000000
--- a/rest_framework/tests/records/models.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from django.db import models
-
-
-class Record(models.Model):
- account = models.ForeignKey('accounts.Account', blank=True, null=True)
- owner = models.ForeignKey('users.User', blank=True, null=True)
diff --git a/rest_framework/tests/serializers.py b/rest_framework/tests/serializers.py
deleted file mode 100644
index cc943c7d..00000000
--- a/rest_framework/tests/serializers.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from rest_framework import serializers
-
-from rest_framework.tests.models import NullableForeignKeySource
-
-
-class NullableFKSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
deleted file mode 100644
index f072b81b..00000000
--- a/rest_framework/tests/test_authentication.py
+++ /dev/null
@@ -1,637 +0,0 @@
-from __future__ import unicode_literals
-from django.contrib.auth.models import User
-from django.http import HttpResponse
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import HTTP_HEADER_ENCODING
-from rest_framework import exceptions
-from rest_framework import permissions
-from rest_framework import renderers
-from rest_framework.response import Response
-from rest_framework import status
-from rest_framework.authentication import (
- BaseAuthentication,
- TokenAuthentication,
- BasicAuthentication,
- SessionAuthentication,
- OAuthAuthentication,
- OAuth2Authentication
-)
-from rest_framework.authtoken.models import Token
-from rest_framework.compat import patterns, url, include
-from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
-from rest_framework.compat import oauth, oauth_provider
-from rest_framework.test import APIRequestFactory, APIClient
-from rest_framework.views import APIView
-import base64
-import time
-import datetime
-
-factory = APIRequestFactory()
-
-
-class MockView(APIView):
- permission_classes = (permissions.IsAuthenticated,)
-
- def get(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-
- def post(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-
- def put(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-
-
-urlpatterns = patterns('',
- (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
- (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
- (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
- (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
- (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
- (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
- permission_classes=[permissions.TokenHasReadWriteScope]))
-)
-
-if oauth2_provider is not None:
- urlpatterns += patterns('',
- url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
- url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
- url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
- permission_classes=[permissions.TokenHasReadWriteScope])),
- )
-
-
-class BasicAuthTests(TestCase):
- """Basic authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- def test_post_form_passing_basic_auth(self):
- """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
- credentials = ('%s:%s' % (self.username, self.password))
- base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
- auth = 'Basic %s' % base64_credentials
- response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_json_passing_basic_auth(self):
- """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
- credentials = ('%s:%s' % (self.username, self.password))
- base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
- auth = 'Basic %s' % base64_credentials
- response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_form_failing_basic_auth(self):
- """Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/basic/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
- def test_post_json_failing_basic_auth(self):
- """Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
- self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
-
-
-class SessionAuthTests(TestCase):
- """User session authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.non_csrf_client = APIClient(enforce_csrf_checks=False)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- def tearDown(self):
- self.csrf_client.logout()
-
- def test_post_form_session_auth_failing_csrf(self):
- """
- Ensure POSTing form over session authentication without CSRF token fails.
- """
- self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/session/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_post_form_session_auth_passing(self):
- """
- Ensure POSTing form over session authentication with logged in user and CSRF token passes.
- """
- self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.post('/session/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_put_form_session_auth_passing(self):
- """
- Ensure PUTting form over session authentication with logged in user and CSRF token passes.
- """
- self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.put('/session/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_form_session_auth_failing(self):
- """
- Ensure POSTing form over session authentication without logged in user fails.
- """
- response = self.csrf_client.post('/session/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
-
-class TokenAuthTests(TestCase):
- """Token authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- self.key = 'abcd1234'
- self.token = Token.objects.create(key=self.key, user=self.user)
-
- def test_post_form_passing_token_auth(self):
- """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
- auth = 'Token ' + self.key
- response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_json_passing_token_auth(self):
- """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
- auth = "Token " + self.key
- response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_post_form_failing_token_auth(self):
- """Ensure POSTing form over token auth without correct credentials fails"""
- response = self.csrf_client.post('/token/', {'example': 'example'})
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
- def test_post_json_failing_token_auth(self):
- """Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
- self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
- def test_token_has_auto_assigned_key_if_none_provided(self):
- """Ensure creating a token with no key will auto-assign a key"""
- self.token.delete()
- token = Token.objects.create(user=self.user)
- self.assertTrue(bool(token.key))
-
- def test_token_login_json(self):
- """Ensure token login view using JSON POST works."""
- client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username, 'password': self.password}, format='json')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['token'], self.key)
-
- def test_token_login_json_bad_creds(self):
- """Ensure token login view using JSON POST fails if bad credentials are used."""
- client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username, 'password': "badpass"}, format='json')
- self.assertEqual(response.status_code, 400)
-
- def test_token_login_json_missing_fields(self):
- """Ensure token login view using JSON POST fails if missing fields."""
- client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username}, format='json')
- self.assertEqual(response.status_code, 400)
-
- def test_token_login_form(self):
- """Ensure token login view using form POST works."""
- client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username, 'password': self.password})
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['token'], self.key)
-
-
-class IncorrectCredentialsTests(TestCase):
- def test_incorrect_credentials(self):
- """
- If a request contains bad authentication credentials, then
- authentication should run and error, even if no permissions
- are set on the view.
- """
- class IncorrectCredentialsAuth(BaseAuthentication):
- def authenticate(self, request):
- raise exceptions.AuthenticationFailed('Bad credentials')
-
- request = factory.get('/')
- view = MockView.as_view(
- authentication_classes=(IncorrectCredentialsAuth,),
- permission_classes=()
- )
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(response.data, {'detail': 'Bad credentials'})
-
-
-class OAuthTests(TestCase):
- """OAuth 1.0a authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- # these imports are here because oauth is optional and hiding them in try..except block or compat
- # could obscure problems if something breaks
- from oauth_provider.models import Consumer, Scope
- from oauth_provider.models import Token as OAuthToken
- from oauth_provider import consts
-
- self.consts = consts
-
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- self.CONSUMER_KEY = 'consumer_key'
- self.CONSUMER_SECRET = 'consumer_secret'
- self.TOKEN_KEY = "token_key"
- self.TOKEN_SECRET = "token_secret"
-
- self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
- name='example', user=self.user, status=self.consts.ACCEPTED)
-
- self.scope = Scope.objects.create(name="resource name", url="api/")
- self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope,
- token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
- )
-
- def _create_authorization_header(self):
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="GET", url="http://example.com", parameters=params)
-
- signature_method = oauth.SignatureMethod_PLAINTEXT()
- req.sign_request(signature_method, self.consumer, self.token)
-
- return req.to_header()["Authorization"]
-
- def _create_authorization_url_parameters(self):
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="GET", url="http://example.com", parameters=params)
-
- signature_method = oauth.SignatureMethod_PLAINTEXT()
- req.sign_request(signature_method, self.consumer, self.token)
- return dict(req)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_passing_oauth(self):
- """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_repeated_nonce_failing_oauth(self):
- """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- # simulate reply attack auth header containes already used (nonce, timestamp) pair
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_token_removed_failing_oauth(self):
- """Ensure POSTing when there is no OAuth access token in db fails"""
- self.token.delete()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_consumer_status_not_accepted_failing_oauth(self):
- """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
- for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
- self.consumer.status = consumer_status
- self.consumer.save()
-
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_request_token_failing_oauth(self):
- """Ensure POSTing with unauthorized request token instead of access token fails"""
- self.token.token_type = self.token.REQUEST
- self.token.save()
-
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_urlencoded_parameters(self):
- """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
- params = self._create_authorization_url_parameters()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_get_form_with_url_parameters(self):
- """Ensure GETing with auth in url parameters passes"""
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.get('/oauth/', params)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_hmac_sha1_signature_passes(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_get_form_with_readonly_resource_passing_auth(self):
- """Ensure POSTing with a readonly scope instead of a write scope fails"""
- read_only_access_token = self.token
- read_only_access_token.scope.is_readonly = True
- read_only_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.get('/oauth-with-scope/', params)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_readonly_resource_failing_auth(self):
- """Ensure POSTing with a readonly resource instead of a write scope fails"""
- read_only_access_token = self.token
- read_only_access_token.scope.is_readonly = True
- read_only_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- response = self.csrf_client.post('/oauth-with-scope/', params)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_post_form_with_write_resource_passing_auth(self):
- """Ensure POSTing with a write resource succeed"""
- read_write_access_token = self.token
- read_write_access_token.scope.is_readonly = False
- read_write_access_token.scope.save()
- params = self._create_authorization_url_parameters()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_bad_consumer_key(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': self.token.key,
- 'oauth_consumer_key': 'badconsumerkey'
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
- @unittest.skipUnless(oauth, 'oauth2 not installed')
- def test_bad_token_key(self):
- """Ensure POSTing using HMAC_SHA1 signature method passes"""
- params = {
- 'oauth_version': "1.0",
- 'oauth_nonce': oauth.generate_nonce(),
- 'oauth_timestamp': int(time.time()),
- 'oauth_token': 'badtokenkey',
- 'oauth_consumer_key': self.consumer.key
- }
-
- req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
-
- signature_method = oauth.SignatureMethod_HMAC_SHA1()
- req.sign_request(signature_method, self.consumer, self.token)
- auth = req.to_header()["Authorization"]
-
- response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
-
-class OAuth2Tests(TestCase):
- """OAuth 2.0 authentication"""
- urls = 'rest_framework.tests.test_authentication'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- self.CLIENT_ID = 'client_key'
- self.CLIENT_SECRET = 'client_secret'
- self.ACCESS_TOKEN = "access_token"
- self.REFRESH_TOKEN = "refresh_token"
-
- self.oauth2_client = oauth2_provider_models.Client.objects.create(
- client_id=self.CLIENT_ID,
- client_secret=self.CLIENT_SECRET,
- redirect_uri='',
- client_type=0,
- name='example',
- user=None,
- )
-
- self.access_token = oauth2_provider_models.AccessToken.objects.create(
- token=self.ACCESS_TOKEN,
- client=self.oauth2_client,
- user=self.user,
- )
- self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
- user=self.user,
- access_token=self.access_token,
- client=self.oauth2_client
- )
-
- def _create_authorization_header(self, token=None):
- return "Bearer {0}".format(token or self.access_token.token)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_type_failing(self):
- """Ensure that a wrong token type lead to the correct HTTP error status code"""
- auth = "Wrong token-type-obsviously"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_format_failing(self):
- """Ensure that a wrong token format lead to the correct HTTP error status code"""
- auth = "Bearer wrong token format"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_authorization_header_token_failing(self):
- """Ensure that a wrong token lead to the correct HTTP error status code"""
- auth = "Bearer wrong-token"
- response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_passing_auth(self):
- """Ensure GETing form over OAuth with correct client credentials succeed"""
- auth = self._create_authorization_header()
- response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_passing_auth(self):
- """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_token_removed_failing_auth(self):
- """Ensure POSTing when there is no OAuth access token in db fails"""
- self.access_token.delete()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_refresh_token_failing_auth(self):
- """Ensure POSTing with refresh token instead of access token fails"""
- auth = self._create_authorization_header(token=self.refresh_token.token)
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_expired_access_token_failing_auth(self):
- """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
- self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
- self.access_token.save()
- auth = self._create_authorization_header()
- response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
- self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
- self.assertIn('Invalid token', response.content)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_invalid_scope_failing_auth(self):
- """Ensure POSTing with a readonly scope instead of a write scope fails"""
- read_only_access_token = self.access_token
- read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
- read_only_access_token.save()
- auth = self._create_authorization_header(token=read_only_access_token.token)
- response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
- response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_post_form_with_valid_scope_passing_auth(self):
- """Ensure POSTing with a write scope succeed"""
- read_write_access_token = self.access_token
- read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
- read_write_access_token.save()
- auth = self._create_authorization_header(token=read_write_access_token.token)
- response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 200)
-
-
-class FailingAuthAccessedInRenderer(TestCase):
- def setUp(self):
- class AuthAccessingRenderer(renderers.BaseRenderer):
- media_type = 'text/plain'
- format = 'txt'
-
- def render(self, data, media_type=None, renderer_context=None):
- request = renderer_context['request']
- if request.user.is_authenticated():
- return b'authenticated'
- return b'not authenticated'
-
- class FailingAuth(BaseAuthentication):
- def authenticate(self, request):
- raise exceptions.AuthenticationFailed('authentication failed')
-
- class ExampleView(APIView):
- authentication_classes = (FailingAuth,)
- renderer_classes = (AuthAccessingRenderer,)
-
- def get(self, request):
- return Response({'foo': 'bar'})
-
- self.view = ExampleView.as_view()
-
- def test_failing_auth_accessed_in_renderer(self):
- """
- When authentication fails the renderer should still be able to access
- `request.user` without raising an exception. Particularly relevant
- to HTML responses that might reasonably access `request.user`.
- """
- request = factory.get('/')
- response = self.view(request)
- content = response.render().content
- self.assertEqual(content, b'not authenticated')
diff --git a/rest_framework/tests/test_breadcrumbs.py b/rest_framework/tests/test_breadcrumbs.py
deleted file mode 100644
index 41ddf2ce..00000000
--- a/rest_framework/tests/test_breadcrumbs.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.compat import patterns, url
-from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework.views import APIView
-
-
-class Root(APIView):
- pass
-
-
-class ResourceRoot(APIView):
- pass
-
-
-class ResourceInstance(APIView):
- pass
-
-
-class NestedResourceRoot(APIView):
- pass
-
-
-class NestedResourceInstance(APIView):
- pass
-
-urlpatterns = patterns('',
- url(r'^$', Root.as_view()),
- url(r'^resource/$', ResourceRoot.as_view()),
- url(r'^resource/(?P[0-9]+)$', ResourceInstance.as_view()),
- url(r'^resource/(?P[0-9]+)/$', NestedResourceRoot.as_view()),
- url(r'^resource/(?P[0-9]+)/(?P[A-Za-z]+)$', NestedResourceInstance.as_view()),
-)
-
-
-class BreadcrumbTests(TestCase):
- """Tests the breadcrumb functionality used by the HTML renderer."""
-
- urls = 'rest_framework.tests.test_breadcrumbs'
-
- def test_root_breadcrumbs(self):
- url = '/'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
-
- def test_resource_root_breadcrumbs(self):
- url = '/resource/'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
- ('Resource Root', '/resource/')])
-
- def test_resource_instance_breadcrumbs(self):
- url = '/resource/123'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123')])
-
- def test_nested_resource_breadcrumbs(self):
- url = '/resource/123/'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123'),
- ('Nested Resource Root', '/resource/123/')])
-
- def test_nested_resource_instance_breadcrumbs(self):
- url = '/resource/123/abc'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123'),
- ('Nested Resource Root', '/resource/123/'),
- ('Nested Resource Instance', '/resource/123/abc')])
-
- def test_broken_url_breadcrumbs_handled_gracefully(self):
- url = '/foobar'
- self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
diff --git a/rest_framework/tests/test_decorators.py b/rest_framework/tests/test_decorators.py
deleted file mode 100644
index 195f0ba3..00000000
--- a/rest_framework/tests/test_decorators.py
+++ /dev/null
@@ -1,157 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import status
-from rest_framework.authentication import BasicAuthentication
-from rest_framework.parsers import JSONParser
-from rest_framework.permissions import IsAuthenticated
-from rest_framework.response import Response
-from rest_framework.renderers import JSONRenderer
-from rest_framework.test import APIRequestFactory
-from rest_framework.throttling import UserRateThrottle
-from rest_framework.views import APIView
-from rest_framework.decorators import (
- api_view,
- renderer_classes,
- parser_classes,
- authentication_classes,
- throttle_classes,
- permission_classes,
-)
-
-
-class DecoratorTestCase(TestCase):
-
- def setUp(self):
- self.factory = APIRequestFactory()
-
- def _finalize_response(self, request, response, *args, **kwargs):
- response.request = request
- return APIView.finalize_response(self, request, response, *args, **kwargs)
-
- def test_api_view_incorrect(self):
- """
- If @api_view is not applied correct, we should raise an assertion.
- """
-
- @api_view
- def view(request):
- return Response()
-
- request = self.factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- def test_api_view_incorrect_arguments(self):
- """
- If @api_view is missing arguments, we should raise an assertion.
- """
-
- with self.assertRaises(AssertionError):
- @api_view('GET')
- def view(request):
- return Response()
-
- def test_calling_method(self):
-
- @api_view(['GET'])
- def view(request):
- return Response({})
-
- request = self.factory.get('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- request = self.factory.post('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
-
- def test_calling_put_method(self):
-
- @api_view(['GET', 'PUT'])
- def view(request):
- return Response({})
-
- request = self.factory.put('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- request = self.factory.post('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
-
- def test_calling_patch_method(self):
-
- @api_view(['GET', 'PATCH'])
- def view(request):
- return Response({})
-
- request = self.factory.patch('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- request = self.factory.post('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
-
- def test_renderer_classes(self):
-
- @api_view(['GET'])
- @renderer_classes([JSONRenderer])
- def view(request):
- return Response({})
-
- request = self.factory.get('/')
- response = view(request)
- self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer))
-
- def test_parser_classes(self):
-
- @api_view(['GET'])
- @parser_classes([JSONParser])
- def view(request):
- self.assertEqual(len(request.parsers), 1)
- self.assertTrue(isinstance(request.parsers[0],
- JSONParser))
- return Response({})
-
- request = self.factory.get('/')
- view(request)
-
- def test_authentication_classes(self):
-
- @api_view(['GET'])
- @authentication_classes([BasicAuthentication])
- def view(request):
- self.assertEqual(len(request.authenticators), 1)
- self.assertTrue(isinstance(request.authenticators[0],
- BasicAuthentication))
- return Response({})
-
- request = self.factory.get('/')
- view(request)
-
- def test_permission_classes(self):
-
- @api_view(['GET'])
- @permission_classes([IsAuthenticated])
- def view(request):
- return Response({})
-
- request = self.factory.get('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_throttle_classes(self):
- class OncePerDayUserThrottle(UserRateThrottle):
- rate = '1/day'
-
- @api_view(['GET'])
- @throttle_classes([OncePerDayUserThrottle])
- def view(request):
- return Response({})
-
- request = self.factory.get('/')
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- response = view(request)
- self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py
deleted file mode 100644
index 4c03c1de..00000000
--- a/rest_framework/tests/test_description.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# -- coding: utf-8 --
-
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.compat import apply_markdown, smart_text
-from rest_framework.views import APIView
-from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring
-from rest_framework.tests.description import UTF8_TEST_DOCSTRING
-
-# We check that docstrings get nicely un-indented.
-DESCRIPTION = """an example docstring
-====================
-
-* list
-* list
-
-another header
---------------
-
- code block
-
-indented
-
-# hash style header #"""
-
-# If markdown is installed we also test it's working
-# (and that our wrapped forces '=' to h2 and '-' to h3)
-
-# We support markdown < 2.1 and markdown >= 2.1
-MARKED_DOWN_lt_21 = """
an example docstring
-
-
list
-
list
-
-
another header
-
code block
-
-
indented
-
hash style header
"""
-
-MARKED_DOWN_gte_21 = """
an example docstring
-
-
list
-
list
-
-
another header
-
code block
-
-
indented
-
hash style header
"""
-
-
-class TestViewNamesAndDescriptions(TestCase):
- def test_view_name_uses_class_name(self):
- """
- Ensure view names are based on the class name.
- """
- class MockView(APIView):
- pass
- self.assertEqual(MockView().get_view_name(), 'Mock')
-
- def test_view_description_uses_docstring(self):
- """Ensure view descriptions are based on the docstring."""
- class MockView(APIView):
- """an example docstring
- ====================
-
- * list
- * list
-
- another header
- --------------
-
- code block
-
- indented
-
- # hash style header #"""
-
- self.assertEqual(MockView().get_view_description(), DESCRIPTION)
-
- def test_view_description_supports_unicode(self):
- """
- Unicode in docstrings should be respected.
- """
-
- self.assertEqual(
- ViewWithNonASCIICharactersInDocstring().get_view_description(),
- smart_text(UTF8_TEST_DOCSTRING)
- )
-
- def test_view_description_can_be_empty(self):
- """
- Ensure that if a view has no docstring,
- then it's description is the empty string.
- """
- class MockView(APIView):
- pass
- self.assertEqual(MockView().get_view_description(), '')
-
- def test_markdown(self):
- """
- Ensure markdown to HTML works as expected.
- """
- if apply_markdown:
- gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
- lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
- self.assertTrue(gte_21_match or lt_21_match)
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
deleted file mode 100644
index e127feef..00000000
--- a/rest_framework/tests/test_fields.py
+++ /dev/null
@@ -1,984 +0,0 @@
-"""
-General serializer field tests.
-"""
-from __future__ import unicode_literals
-
-import datetime
-from decimal import Decimal
-from uuid import uuid4
-from django.core import validators
-from django.db import models
-from django.test import TestCase
-from django.utils.datastructures import SortedDict
-from rest_framework import serializers
-from rest_framework.tests.models import RESTFrameworkModel
-
-
-class TimestampedModel(models.Model):
- added = models.DateTimeField(auto_now_add=True)
- updated = models.DateTimeField(auto_now=True)
-
-
-class CharPrimaryKeyModel(models.Model):
- id = models.CharField(max_length=20, primary_key=True)
-
-
-class TimestampedModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = TimestampedModel
-
-
-class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = CharPrimaryKeyModel
-
-
-class TimeFieldModel(models.Model):
- clock = models.TimeField()
-
-
-class TimeFieldModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = TimeFieldModel
-
-
-SAMPLE_CHOICES = [
- ('red', 'Red'),
- ('green', 'Green'),
- ('blue', 'Blue'),
-]
-
-
-class ChoiceFieldModel(models.Model):
- choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255)
-
-
-class ChoiceFieldModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = ChoiceFieldModel
-
-
-class ChoiceFieldModelWithNull(models.Model):
- choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255)
-
-
-class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer):
- class Meta:
- model = ChoiceFieldModelWithNull
-
-
-class BasicFieldTests(TestCase):
- def test_auto_now_fields_read_only(self):
- """
- auto_now and auto_now_add fields should be read_only by default.
- """
- serializer = TimestampedModelSerializer()
- self.assertEqual(serializer.fields['added'].read_only, True)
-
- def test_auto_pk_fields_read_only(self):
- """
- AutoField fields should be read_only by default.
- """
- serializer = TimestampedModelSerializer()
- self.assertEqual(serializer.fields['id'].read_only, True)
-
- def test_non_auto_pk_fields_not_read_only(self):
- """
- PK fields other than AutoField fields should not be read_only by default.
- """
- serializer = CharPrimaryKeyModelSerializer()
- self.assertEqual(serializer.fields['id'].read_only, False)
-
- def test_dict_field_ordering(self):
- """
- Field should preserve dictionary ordering, if it exists.
- See: https://github.com/tomchristie/django-rest-framework/issues/832
- """
- ret = SortedDict()
- ret['c'] = 1
- ret['b'] = 1
- ret['a'] = 1
- ret['z'] = 1
- field = serializers.Field()
- keys = list(field.to_native(ret).keys())
- self.assertEqual(keys, ['c', 'b', 'a', 'z'])
-
-
-class DateFieldTest(TestCase):
- """
- Tests for the DateFieldTest from_native() and to_native() behavior
- """
-
- def test_from_native_string(self):
- """
- Make sure from_native() accepts default iso input formats.
- """
- f = serializers.DateField()
- result_1 = f.from_native('1984-07-31')
-
- self.assertEqual(datetime.date(1984, 7, 31), result_1)
-
- def test_from_native_datetime_date(self):
- """
- Make sure from_native() accepts a datetime.date instance.
- """
- f = serializers.DateField()
- result_1 = f.from_native(datetime.date(1984, 7, 31))
-
- self.assertEqual(result_1, datetime.date(1984, 7, 31))
-
- def test_from_native_custom_format(self):
- """
- Make sure from_native() accepts custom input formats.
- """
- f = serializers.DateField(input_formats=['%Y -- %d'])
- result = f.from_native('1984 -- 31')
-
- self.assertEqual(datetime.date(1984, 1, 31), result)
-
- def test_from_native_invalid_default_on_custom_format(self):
- """
- Make sure from_native() don't accept default formats if custom format is preset
- """
- f = serializers.DateField(input_formats=['%Y -- %d'])
-
- try:
- f.from_native('1984-07-31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns None on empty param.
- """
- f = serializers.DateField()
- result = f.from_native('')
-
- self.assertEqual(result, None)
-
- def test_from_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DateField()
- result = f.from_native(None)
-
- self.assertEqual(result, None)
-
- def test_from_native_invalid_date(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid date.
- """
- f = serializers.DateField()
-
- try:
- f.from_native('1984-13-31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_invalid_format(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid format.
- """
- f = serializers.DateField()
-
- try:
- f.from_native('1984 -- 31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_to_native(self):
- """
- Make sure to_native() returns datetime as default.
- """
- f = serializers.DateField()
-
- result_1 = f.to_native(datetime.date(1984, 7, 31))
-
- self.assertEqual(datetime.date(1984, 7, 31), result_1)
-
- def test_to_native_iso(self):
- """
- Make sure to_native() with 'iso-8601' returns iso formated date.
- """
- f = serializers.DateField(format='iso-8601')
-
- result_1 = f.to_native(datetime.date(1984, 7, 31))
-
- self.assertEqual('1984-07-31', result_1)
-
- def test_to_native_custom_format(self):
- """
- Make sure to_native() returns correct custom format.
- """
- f = serializers.DateField(format="%Y - %m.%d")
-
- result_1 = f.to_native(datetime.date(1984, 7, 31))
-
- self.assertEqual('1984 - 07.31', result_1)
-
- def test_to_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DateField(required=False)
- self.assertEqual(None, f.to_native(None))
-
-
-class DateTimeFieldTest(TestCase):
- """
- Tests for the DateTimeField from_native() and to_native() behavior
- """
-
- def test_from_native_string(self):
- """
- Make sure from_native() accepts default iso input formats.
- """
- f = serializers.DateTimeField()
- result_1 = f.from_native('1984-07-31 04:31')
- result_2 = f.from_native('1984-07-31 04:31:59')
- result_3 = f.from_native('1984-07-31 04:31:59.000200')
-
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
-
- def test_from_native_datetime_datetime(self):
- """
- Make sure from_native() accepts a datetime.datetime instance.
- """
- f = serializers.DateTimeField()
- result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
- result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
- result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
- self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
- self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- def test_from_native_custom_format(self):
- """
- Make sure from_native() accepts custom input formats.
- """
- f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
- result = f.from_native('1984 -- 04:59')
-
- self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
-
- def test_from_native_invalid_default_on_custom_format(self):
- """
- Make sure from_native() don't accept default formats if custom format is preset
- """
- f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
-
- try:
- f.from_native('1984-07-31 04:31:59')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns None on empty param.
- """
- f = serializers.DateTimeField()
- result = f.from_native('')
-
- self.assertEqual(result, None)
-
- def test_from_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DateTimeField()
- result = f.from_native(None)
-
- self.assertEqual(result, None)
-
- def test_from_native_invalid_datetime(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid datetime.
- """
- f = serializers.DateTimeField()
-
- try:
- f.from_native('04:61:59')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
- "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_invalid_format(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid format.
- """
- f = serializers.DateTimeField()
-
- try:
- f.from_native('04 -- 31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
- "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_to_native(self):
- """
- Make sure to_native() returns isoformat as default.
- """
- f = serializers.DateTimeField()
-
- result_1 = f.to_native(datetime.datetime(1984, 7, 31))
- result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
- result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
- result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
- self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
-
- def test_to_native_iso(self):
- """
- Make sure to_native() with format=iso-8601 returns iso formatted datetime.
- """
- f = serializers.DateTimeField(format='iso-8601')
-
- result_1 = f.to_native(datetime.datetime(1984, 7, 31))
- result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
- result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
- result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- self.assertEqual('1984-07-31T00:00:00', result_1)
- self.assertEqual('1984-07-31T04:31:00', result_2)
- self.assertEqual('1984-07-31T04:31:59', result_3)
- self.assertEqual('1984-07-31T04:31:59.000200', result_4)
-
- def test_to_native_custom_format(self):
- """
- Make sure to_native() returns correct custom format.
- """
- f = serializers.DateTimeField(format="%Y - %H:%M")
-
- result_1 = f.to_native(datetime.datetime(1984, 7, 31))
- result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
- result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
- result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
-
- self.assertEqual('1984 - 00:00', result_1)
- self.assertEqual('1984 - 04:31', result_2)
- self.assertEqual('1984 - 04:31', result_3)
- self.assertEqual('1984 - 04:31', result_4)
-
- def test_to_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DateTimeField(required=False)
- self.assertEqual(None, f.to_native(None))
-
-
-class TimeFieldTest(TestCase):
- """
- Tests for the TimeField from_native() and to_native() behavior
- """
-
- def test_from_native_string(self):
- """
- Make sure from_native() accepts default iso input formats.
- """
- f = serializers.TimeField()
- result_1 = f.from_native('04:31')
- result_2 = f.from_native('04:31:59')
- result_3 = f.from_native('04:31:59.000200')
-
- self.assertEqual(datetime.time(4, 31), result_1)
- self.assertEqual(datetime.time(4, 31, 59), result_2)
- self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
-
- def test_from_native_datetime_time(self):
- """
- Make sure from_native() accepts a datetime.time instance.
- """
- f = serializers.TimeField()
- result_1 = f.from_native(datetime.time(4, 31))
- result_2 = f.from_native(datetime.time(4, 31, 59))
- result_3 = f.from_native(datetime.time(4, 31, 59, 200))
-
- self.assertEqual(result_1, datetime.time(4, 31))
- self.assertEqual(result_2, datetime.time(4, 31, 59))
- self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
-
- def test_from_native_custom_format(self):
- """
- Make sure from_native() accepts custom input formats.
- """
- f = serializers.TimeField(input_formats=['%H -- %M'])
- result = f.from_native('04 -- 31')
-
- self.assertEqual(datetime.time(4, 31), result)
-
- def test_from_native_invalid_default_on_custom_format(self):
- """
- Make sure from_native() don't accept default formats if custom format is preset
- """
- f = serializers.TimeField(input_formats=['%H -- %M'])
-
- try:
- f.from_native('04:31:59')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns None on empty param.
- """
- f = serializers.TimeField()
- result = f.from_native('')
-
- self.assertEqual(result, None)
-
- def test_from_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.TimeField()
- result = f.from_native(None)
-
- self.assertEqual(result, None)
-
- def test_from_native_invalid_time(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid time.
- """
- f = serializers.TimeField()
-
- try:
- f.from_native('04:61:59')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
- "hh:mm[:ss[.uuuuuu]]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_invalid_format(self):
- """
- Make sure from_native() raises a ValidationError on passing an invalid format.
- """
- f = serializers.TimeField()
-
- try:
- f.from_native('04 -- 31')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
- "hh:mm[:ss[.uuuuuu]]"])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_to_native(self):
- """
- Make sure to_native() returns time object as default.
- """
- f = serializers.TimeField()
- result_1 = f.to_native(datetime.time(4, 31))
- result_2 = f.to_native(datetime.time(4, 31, 59))
- result_3 = f.to_native(datetime.time(4, 31, 59, 200))
-
- self.assertEqual(datetime.time(4, 31), result_1)
- self.assertEqual(datetime.time(4, 31, 59), result_2)
- self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
-
- def test_to_native_iso(self):
- """
- Make sure to_native() with format='iso-8601' returns iso formatted time.
- """
- f = serializers.TimeField(format='iso-8601')
- result_1 = f.to_native(datetime.time(4, 31))
- result_2 = f.to_native(datetime.time(4, 31, 59))
- result_3 = f.to_native(datetime.time(4, 31, 59, 200))
-
- self.assertEqual('04:31:00', result_1)
- self.assertEqual('04:31:59', result_2)
- self.assertEqual('04:31:59.000200', result_3)
-
- def test_to_native_custom_format(self):
- """
- Make sure to_native() returns correct custom format.
- """
- f = serializers.TimeField(format="%H - %S [%f]")
- result_1 = f.to_native(datetime.time(4, 31))
- result_2 = f.to_native(datetime.time(4, 31, 59))
- result_3 = f.to_native(datetime.time(4, 31, 59, 200))
-
- self.assertEqual('04 - 00 [000000]', result_1)
- self.assertEqual('04 - 59 [000000]', result_2)
- self.assertEqual('04 - 59 [000200]', result_3)
-
-
-class DecimalFieldTest(TestCase):
- """
- Tests for the DecimalField from_native() and to_native() behavior
- """
-
- def test_from_native_string(self):
- """
- Make sure from_native() accepts string values
- """
- f = serializers.DecimalField()
- result_1 = f.from_native('9000')
- result_2 = f.from_native('1.00000001')
-
- self.assertEqual(Decimal('9000'), result_1)
- self.assertEqual(Decimal('1.00000001'), result_2)
-
- def test_from_native_invalid_string(self):
- """
- Make sure from_native() raises ValidationError on passing invalid string
- """
- f = serializers.DecimalField()
-
- try:
- f.from_native('123.45.6')
- except validators.ValidationError as e:
- self.assertEqual(e.messages, ["Enter a number."])
- else:
- self.fail("ValidationError was not properly raised")
-
- def test_from_native_integer(self):
- """
- Make sure from_native() accepts integer values
- """
- f = serializers.DecimalField()
- result = f.from_native(9000)
-
- self.assertEqual(Decimal('9000'), result)
-
- def test_from_native_float(self):
- """
- Make sure from_native() accepts float values
- """
- f = serializers.DecimalField()
- result = f.from_native(1.00000001)
-
- self.assertEqual(Decimal('1.00000001'), result)
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns None on empty param.
- """
- f = serializers.DecimalField()
- result = f.from_native('')
-
- self.assertEqual(result, None)
-
- def test_from_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DecimalField()
- result = f.from_native(None)
-
- self.assertEqual(result, None)
-
- def test_to_native(self):
- """
- Make sure to_native() returns Decimal as string.
- """
- f = serializers.DecimalField()
-
- result_1 = f.to_native(Decimal('9000'))
- result_2 = f.to_native(Decimal('1.00000001'))
-
- self.assertEqual(Decimal('9000'), result_1)
- self.assertEqual(Decimal('1.00000001'), result_2)
-
- def test_to_native_none(self):
- """
- Make sure from_native() returns None on None param.
- """
- f = serializers.DecimalField(required=False)
- self.assertEqual(None, f.to_native(None))
-
- def test_valid_serialization(self):
- """
- Make sure the serializer works correctly
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(max_value=9010,
- min_value=9000,
- max_digits=6,
- decimal_places=2)
-
- self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
- self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
- self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
-
- self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
- self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
- self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
-
- def test_raise_max_value(self):
- """
- Make sure max_value violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(max_value=100)
-
- s = DecimalSerializer(data={'decimal_field': '123'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
-
- def test_raise_min_value(self):
- """
- Make sure min_value violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(min_value=100)
-
- s = DecimalSerializer(data={'decimal_field': '99'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
-
- def test_raise_max_digits(self):
- """
- Make sure max_digits violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(max_digits=5)
-
- s = DecimalSerializer(data={'decimal_field': '123.456'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
-
- def test_raise_max_decimal_places(self):
- """
- Make sure max_decimal_places violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(decimal_places=3)
-
- s = DecimalSerializer(data={'decimal_field': '123.4567'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
-
- def test_raise_max_whole_digits(self):
- """
- Make sure max_whole_digits violations raises ValidationError
- """
- class DecimalSerializer(serializers.Serializer):
- decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
-
- s = DecimalSerializer(data={'decimal_field': '12345.6'})
-
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
-
-
-class ChoiceFieldTests(TestCase):
- """
- Tests for the ChoiceField options generator
- """
- def test_choices_required(self):
- """
- Make sure proper choices are rendered if field is required
- """
- f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES)
- self.assertEqual(f.choices, SAMPLE_CHOICES)
-
- def test_choices_not_required(self):
- """
- Make sure proper choices (plus blank) are rendered if the field isn't required
- """
- f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
- self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
-
- def test_invalid_choice_model(self):
- s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
- self.assertFalse(s.is_valid())
- self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
- self.assertEqual(s.data['choice'], '')
-
- def test_empty_choice_model(self):
- """
- Test that the 'empty' value is correctly passed and used depending on
- the 'null' property on the model field.
- """
- s = ChoiceFieldModelSerializer(data={'choice': ''})
- self.assertTrue(s.is_valid())
- self.assertEqual(s.data['choice'], '')
-
- s = ChoiceFieldModelWithNullSerializer(data={'choice': ''})
- self.assertTrue(s.is_valid())
- self.assertEqual(s.data['choice'], None)
-
- def test_from_native_empty(self):
- """
- Make sure from_native() returns an empty string on empty param by default.
- """
- f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
- self.assertEqual(f.from_native(''), '')
- self.assertEqual(f.from_native(None), '')
-
- def test_from_native_empty_override(self):
- """
- Make sure you can override from_native() behavior regarding empty values.
- """
- f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None)
- self.assertEqual(f.from_native(''), None)
- self.assertEqual(f.from_native(None), None)
-
- def test_metadata_choices(self):
- """
- Make sure proper choices are included in the field's metadata.
- """
- choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES]
- f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
- self.assertEqual(f.metadata()['choices'], choices)
-
- def test_metadata_choices_not_required(self):
- """
- Make sure proper choices are included in the field's metadata.
- """
- choices = [{'value': v, 'display_name': n}
- for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES]
- f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
- self.assertEqual(f.metadata()['choices'], choices)
-
-
-class EmailFieldTests(TestCase):
- """
- Tests for EmailField attribute values
- """
-
- class EmailFieldModel(RESTFrameworkModel):
- email_field = models.EmailField(blank=True)
-
- class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel):
- email_field = models.EmailField(max_length=150, blank=True)
-
- def test_default_model_value(self):
- class EmailFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.EmailFieldModel
-
- serializer = EmailFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75)
-
- def test_given_model_value(self):
- class EmailFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.EmailFieldWithGivenMaxLengthModel
-
- serializer = EmailFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150)
-
- def test_given_serializer_value(self):
- class EmailFieldSerializer(serializers.ModelSerializer):
- email_field = serializers.EmailField(source='email_field', max_length=20, required=False)
-
- class Meta:
- model = self.EmailFieldModel
-
- serializer = EmailFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20)
-
-
-class SlugFieldTests(TestCase):
- """
- Tests for SlugField attribute values
- """
-
- class SlugFieldModel(RESTFrameworkModel):
- slug_field = models.SlugField(blank=True)
-
- class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel):
- slug_field = models.SlugField(max_length=84, blank=True)
-
- def test_default_model_value(self):
- class SlugFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.SlugFieldModel
-
- serializer = SlugFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50)
-
- def test_given_model_value(self):
- class SlugFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.SlugFieldWithGivenMaxLengthModel
-
- serializer = SlugFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84)
-
- def test_given_serializer_value(self):
- class SlugFieldSerializer(serializers.ModelSerializer):
- slug_field = serializers.SlugField(source='slug_field',
- max_length=20, required=False)
-
- class Meta:
- model = self.SlugFieldModel
-
- serializer = SlugFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['slug_field'],
- 'max_length'), 20)
-
- def test_invalid_slug(self):
- """
- Make sure an invalid slug raises ValidationError
- """
- class SlugFieldSerializer(serializers.ModelSerializer):
- slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True)
-
- class Meta:
- model = self.SlugFieldModel
-
- s = SlugFieldSerializer(data={'slug_field': 'a b'})
-
- self.assertEqual(s.is_valid(), False)
- self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
-
-
-class URLFieldTests(TestCase):
- """
- Tests for URLField attribute values.
-
- (Includes test for #1210, checking that validators can be overridden.)
- """
-
- class URLFieldModel(RESTFrameworkModel):
- url_field = models.URLField(blank=True)
-
- class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel):
- url_field = models.URLField(max_length=128, blank=True)
-
- def test_default_model_value(self):
- class URLFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.URLFieldModel
-
- serializer = URLFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'],
- 'max_length'), 200)
-
- def test_given_model_value(self):
- class URLFieldSerializer(serializers.ModelSerializer):
- class Meta:
- model = self.URLFieldWithGivenMaxLengthModel
-
- serializer = URLFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'],
- 'max_length'), 128)
-
- def test_given_serializer_value(self):
- class URLFieldSerializer(serializers.ModelSerializer):
- url_field = serializers.URLField(source='url_field',
- max_length=20, required=False)
-
- class Meta:
- model = self.URLFieldWithGivenMaxLengthModel
-
- serializer = URLFieldSerializer(data={})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(getattr(serializer.fields['url_field'],
- 'max_length'), 20)
-
- def test_validators_can_be_overridden(self):
- url_field = serializers.URLField(validators=[])
- validators = url_field.validators
- self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators')
-
-
-class FieldMetadata(TestCase):
- def setUp(self):
- self.required_field = serializers.Field()
- self.required_field.label = uuid4().hex
- self.required_field.required = True
-
- self.optional_field = serializers.Field()
- self.optional_field.label = uuid4().hex
- self.optional_field.required = False
-
- def test_required(self):
- self.assertEqual(self.required_field.metadata()['required'], True)
-
- def test_optional(self):
- self.assertEqual(self.optional_field.metadata()['required'], False)
-
- def test_label(self):
- for field in (self.required_field, self.optional_field):
- self.assertEqual(field.metadata()['label'], field.label)
-
-
-class FieldCallableDefault(TestCase):
- def setUp(self):
- self.simple_callable = lambda: 'foo bar'
-
- def test_default_can_be_simple_callable(self):
- """
- Ensure that the 'default' argument can also be a simple callable.
- """
- field = serializers.WritableField(default=self.simple_callable)
- into = {}
- field.field_from_native({}, {}, 'field', into)
- self.assertEqual(into, {'field': 'foo bar'})
-
-
-class CustomIntegerField(TestCase):
- """
- Test that custom fields apply min_value and max_value constraints
- """
- def test_custom_fields_can_be_validated_for_value(self):
-
- class MoneyField(models.PositiveIntegerField):
- pass
-
- class EntryModel(models.Model):
- bank = MoneyField(validators=[validators.MaxValueValidator(100)])
-
- class EntrySerializer(serializers.ModelSerializer):
- class Meta:
- model = EntryModel
-
- entry = EntryModel(bank=1)
-
- serializer = EntrySerializer(entry, data={"bank": 11})
- self.assertTrue(serializer.is_valid())
-
- serializer = EntrySerializer(entry, data={"bank": -1})
- self.assertFalse(serializer.is_valid())
-
- serializer = EntrySerializer(entry, data={"bank": 101})
- self.assertFalse(serializer.is_valid())
-
-
-class BooleanField(TestCase):
- """
- Tests for BooleanField
- """
- def test_boolean_required(self):
- class BooleanRequiredSerializer(serializers.Serializer):
- bool_field = serializers.BooleanField(required=True)
-
- self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py
deleted file mode 100644
index 78f4cf42..00000000
--- a/rest_framework/tests/test_files.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.compat import BytesIO
-from rest_framework.compat import six
-import datetime
-
-
-class UploadedFile(object):
- def __init__(self, file=None, created=None):
- self.file = file
- self.created = created or datetime.datetime.now()
-
-
-class UploadedFileSerializer(serializers.Serializer):
- file = serializers.FileField(required=False)
- created = serializers.DateTimeField()
-
- def restore_object(self, attrs, instance=None):
- if instance:
- instance.file = attrs['file']
- instance.created = attrs['created']
- return instance
- return UploadedFile(**attrs)
-
-
-class FileSerializerTests(TestCase):
- def test_create(self):
- now = datetime.datetime.now()
- file = BytesIO(six.b('stuff'))
- file.name = 'stuff.txt'
- file.size = len(file.getvalue())
- serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
- uploaded_file = UploadedFile(file=file, created=now)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.object.created, uploaded_file.created)
- self.assertEqual(serializer.object.file, uploaded_file.file)
- self.assertFalse(serializer.object is uploaded_file)
-
- def test_creation_failure(self):
- """
- Passing files=None should result in an ValidationError
-
- Regression test for:
- https://github.com/tomchristie/django-rest-framework/issues/542
- """
- now = datetime.datetime.now()
-
- serializer = UploadedFileSerializer(data={'created': now})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.object.created, now)
- self.assertIsNone(serializer.object.file)
-
- def test_remove_with_empty_string(self):
- """
- Passing empty string as data should cause file to be removed
-
- Test for:
- https://github.com/tomchristie/django-rest-framework/issues/937
- """
- now = datetime.datetime.now()
- file = BytesIO(six.b('stuff'))
- file.name = 'stuff.txt'
- file.size = len(file.getvalue())
-
- uploaded_file = UploadedFile(file=file, created=now)
-
- serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.object.created, uploaded_file.created)
- self.assertIsNone(serializer.object.file)
-
- def test_validation_error_with_non_file(self):
- """
- Passing non-files should raise a validation error.
- """
- now = datetime.datetime.now()
- errmsg = 'No file was submitted. Check the encoding type on the form.'
-
- serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'file': [errmsg]})
-
- def test_validation_with_no_data(self):
- """
- Validation should still function when no data dictionary is provided.
- """
- now = datetime.datetime.now()
- file = BytesIO(six.b('stuff'))
- file.name = 'stuff.txt'
- file.size = len(file.getvalue())
- uploaded_file = UploadedFile(file=file, created=now)
-
- serializer = UploadedFileSerializer(files={'file': file})
- self.assertFalse(serializer.is_valid())
diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py
deleted file mode 100644
index 18188186..00000000
--- a/rest_framework/tests/test_filters.py
+++ /dev/null
@@ -1,615 +0,0 @@
-from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.db import models
-from django.core.urlresolvers import reverse
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import generics, serializers, status, filters
-from rest_framework.compat import django_filters, patterns, url
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel
-
-factory = APIRequestFactory()
-
-
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
-if django_filters:
- # Basic filter on a list view.
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_fields = ['decimal', 'date']
- filter_backends = (filters.DjangoFilterBackend,)
-
- # These class are used to test a filter class.
- class SeveralFieldsFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
- decimal = django_filters.NumberFilter(lookup_type='lt')
- date = django_filters.DateFilter(lookup_type='gt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterClassRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- # These classes are used to test a misconfigured filter class.
- class MisconfiguredFilter(django_filters.FilterSet):
- text = django_filters.CharFilter(lookup_type='icontains')
-
- class Meta:
- model = BasicModel
- fields = ['text']
-
- class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
- model = FilterableItem
- filter_class = MisconfiguredFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- class FilterClassDetailView(generics.RetrieveAPIView):
- model = FilterableItem
- filter_class = SeveralFieldsFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- # Regression test for #814
- class FilterableItemSerializer(serializers.ModelSerializer):
- class Meta:
- model = FilterableItem
-
- class FilterFieldsQuerysetView(generics.ListCreateAPIView):
- queryset = FilterableItem.objects.all()
- serializer_class = FilterableItemSerializer
- filter_fields = ['decimal', 'date']
- filter_backends = (filters.DjangoFilterBackend,)
-
- class GetQuerysetView(generics.ListCreateAPIView):
- serializer_class = FilterableItemSerializer
- filter_class = SeveralFieldsFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- def get_queryset(self):
- return FilterableItem.objects.all()
-
- urlpatterns = patterns('',
- url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
- url(r'^$', FilterClassRootView.as_view(), name='root-view'),
- url(r'^get-queryset/$', GetQuerysetView.as_view(),
- name='get-queryset-view'),
- )
-
-
-class CommonFilteringTestCase(TestCase):
- def _serialize_object(self, obj):
- return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
-
- def setUp(self):
- """
- Create 10 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(10):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- self._serialize_object(obj)
- for obj in self.objects.all()
- ]
-
-
-class IntegrationTestFiltering(CommonFilteringTestCase):
- """
- Integration tests for filtered list views.
- """
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_filtered_fields_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- view = FilterFieldsRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter works.
- search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter works.
- search_date = datetime.date(2012, 9, 22)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] == search_date]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_filter_with_queryset(self):
- """
- Regression test for #814.
- """
- view = FilterFieldsQuerysetView.as_view()
-
- # Tests that the decimal filter works.
- search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] == search_decimal]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_filter_with_get_queryset_only(self):
- """
- Regression test for #834.
- """
- view = GetQuerysetView.as_view()
- request = factory.get('/get-queryset/')
- view(request).render()
- # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
- # here when neither `model' nor `queryset' was specified.
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_filtered_class_root_view(self):
- """
- GET requests to filtered ListCreateAPIView that have a filter_class set
- should return filtered results.
- """
- view = FilterClassRootView.as_view()
-
- # Basic test with no filter.
- request = factory.get('/')
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- # Tests that the decimal filter set with 'lt' in the filter class works.
- search_decimal = Decimal('4.25')
- request = factory.get('/?decimal=%s' % search_decimal)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the date filter set with 'gt' in the filter class works.
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date]
- self.assertEqual(response.data, expected_data)
-
- # Tests that the text filter set with 'icontains' in the filter class works.
- search_text = 'ff'
- request = factory.get('/?text=%s' % search_text)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if search_text in f['text'].lower()]
- self.assertEqual(response.data, expected_data)
-
- # Tests that multiple filters works.
- search_decimal = Decimal('5.25')
- search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if f['date'] > search_date and
- f['decimal'] < search_decimal]
- self.assertEqual(response.data, expected_data)
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_incorrectly_configured_filter(self):
- """
- An error should be displayed when the filter class is misconfigured.
- """
- view = IncorrectlyConfiguredRootView.as_view()
-
- request = factory.get('/')
- self.assertRaises(AssertionError, view, request)
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_unknown_filter(self):
- """
- GET requests with filters that aren't configured should return 200.
- """
- view = FilterFieldsRootView.as_view()
-
- search_integer = 10
- request = factory.get('/?integer=%s' % search_integer)
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-class IntegrationTestDetailFiltering(CommonFilteringTestCase):
- """
- Integration tests for filtered detail views.
- """
- urls = 'rest_framework.tests.test_filters'
-
- def _get_url(self, item):
- return reverse('detail-view', kwargs=dict(pk=item.pk))
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_filtered_detail_view(self):
- """
- GET requests to filtered RetrieveAPIView that have a filter_class set
- should return filtered results.
- """
- item = self.objects.all()[0]
- data = self._serialize_object(item)
-
- # Basic test with no filter.
- response = self.client.get(self._get_url(item))
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, data)
-
- # Tests that the decimal filter set that should fail.
- search_decimal = Decimal('4.25')
- high_item = self.objects.filter(decimal__gt=search_decimal)[0]
- response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- # Tests that the decimal filter set that should succeed.
- search_decimal = Decimal('4.25')
- low_item = self.objects.filter(decimal__lt=search_decimal)[0]
- low_item_data = self._serialize_object(low_item)
- response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, low_item_data)
-
- # Tests that multiple filters works.
- search_decimal = Decimal('5.25')
- search_date = datetime.date(2012, 10, 2)
- valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
- valid_item_data = self._serialize_object(valid_item)
- response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, valid_item_data)
-
-
-class SearchFilterModel(models.Model):
- title = models.CharField(max_length=20)
- text = models.CharField(max_length=100)
-
-
-class SearchFilterTests(TestCase):
- def setUp(self):
- # Sequence of title/text is:
- #
- # z abc
- # zz bcd
- # zzz cde
- # ...
- for idx in range(10):
- title = 'z' * (idx + 1)
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
- SearchFilterModel(title=title, text=text).save()
-
- def test_search(self):
- class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
- filter_backends = (filters.SearchFilter,)
- search_fields = ('title', 'text')
-
- view = SearchListView.as_view()
- request = factory.get('?search=b')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 1, 'title': 'z', 'text': 'abc'},
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
- )
-
- def test_exact_search(self):
- class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
- filter_backends = (filters.SearchFilter,)
- search_fields = ('=title', 'text')
-
- view = SearchListView.as_view()
- request = factory.get('?search=zzz')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'zzz', 'text': 'cde'}
- ]
- )
-
- def test_startswith_search(self):
- class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
- filter_backends = (filters.SearchFilter,)
- search_fields = ('title', '^text')
-
- view = SearchListView.as_view()
- request = factory.get('?search=b')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
- )
-
-
-class OrdringFilterModel(models.Model):
- title = models.CharField(max_length=20)
- text = models.CharField(max_length=100)
-
-
-class OrderingFilterRelatedModel(models.Model):
- related_object = models.ForeignKey(OrdringFilterModel,
- related_name="relateds")
-
-
-class OrderingFilterTests(TestCase):
- def setUp(self):
- # Sequence of title/text is:
- #
- # zyx abc
- # yxw bcd
- # xwv cde
- for idx in range(3):
- title = (
- chr(ord('z') - idx) +
- chr(ord('y') - idx) +
- chr(ord('x') - idx)
- )
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
- OrdringFilterModel(title=title, text=text).save()
-
- def test_ordering(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('?ordering=text')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- ]
- )
-
- def test_reverse_ordering(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('?ordering=-text')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
- )
-
- def test_incorrectfield_ordering(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('?ordering=foobar')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
- )
-
- def test_default_ordering(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- oredering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
- )
-
- def test_default_ordering_using_string(self):
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = 'title'
- ordering_fields = ('text',)
-
- view = OrderingListView.as_view()
- request = factory.get('')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- ]
- )
-
- def test_ordering_by_aggregate_field(self):
- # create some related models to aggregate order by
- num_objs = [2, 5, 3]
- for obj, num_relateds in zip(OrdringFilterModel.objects.all(),
- num_objs):
- for _ in range(num_relateds):
- new_related = OrderingFilterRelatedModel(
- related_object=obj
- )
- new_related.save()
-
- class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
- filter_backends = (filters.OrderingFilter,)
- ordering = 'title'
- ordering_fields = '__all__'
- queryset = OrdringFilterModel.objects.all().annotate(
- models.Count("relateds"))
-
- view = OrderingListView.as_view()
- request = factory.get('?ordering=relateds__count')
- response = view(request)
- self.assertEqual(
- response.data,
- [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- ]
- )
-
-
-class SensitiveOrderingFilterModel(models.Model):
- username = models.CharField(max_length=20)
- password = models.CharField(max_length=100)
-
-
-# Three different styles of serializer.
-# All should allow ordering by username, but not by password.
-class SensitiveDataSerializer1(serializers.ModelSerializer):
- username = serializers.CharField()
-
- class Meta:
- model = SensitiveOrderingFilterModel
- fields = ('id', 'username')
-
-
-class SensitiveDataSerializer2(serializers.ModelSerializer):
- username = serializers.CharField()
- password = serializers.CharField(write_only=True)
-
- class Meta:
- model = SensitiveOrderingFilterModel
- fields = ('id', 'username', 'password')
-
-
-class SensitiveDataSerializer3(serializers.ModelSerializer):
- user = serializers.CharField(source='username')
-
- class Meta:
- model = SensitiveOrderingFilterModel
- fields = ('id', 'user')
-
-
-class SensitiveOrderingFilterTests(TestCase):
- def setUp(self):
- for idx in range(3):
- username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
- password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
- SensitiveOrderingFilterModel(username=username, password=password).save()
-
- def test_order_by_serializer_fields(self):
- for serializer_cls in [
- SensitiveDataSerializer1,
- SensitiveDataSerializer2,
- SensitiveDataSerializer3
- ]:
- class OrderingListView(generics.ListAPIView):
- queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
- filter_backends = (filters.OrderingFilter,)
- serializer_class = serializer_cls
-
- view = OrderingListView.as_view()
- request = factory.get('?ordering=-username')
- response = view(request)
-
- if serializer_cls == SensitiveDataSerializer3:
- username_field = 'user'
- else:
- username_field = 'username'
-
- # Note: Inverse username ordering correctly applied.
- self.assertEqual(
- response.data,
- [
- {'id': 3, username_field: 'userC'},
- {'id': 2, username_field: 'userB'},
- {'id': 1, username_field: 'userA'},
- ]
- )
-
- def test_cannot_order_by_non_serializer_fields(self):
- for serializer_cls in [
- SensitiveDataSerializer1,
- SensitiveDataSerializer2,
- SensitiveDataSerializer3
- ]:
- class OrderingListView(generics.ListAPIView):
- queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
- filter_backends = (filters.OrderingFilter,)
- serializer_class = serializer_cls
-
- view = OrderingListView.as_view()
- request = factory.get('?ordering=password')
- response = view(request)
-
- if serializer_cls == SensitiveDataSerializer3:
- username_field = 'user'
- else:
- username_field = 'username'
-
- # Note: The passwords are not in order. Default ordering is used.
- self.assertEqual(
- response.data,
- [
- {'id': 1, username_field: 'userA'}, # PassB
- {'id': 2, username_field: 'userB'}, # PassC
- {'id': 3, username_field: 'userC'}, # PassA
- ]
- )
\ No newline at end of file
diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py
deleted file mode 100644
index 2d341344..00000000
--- a/rest_framework/tests/test_genericrelations.py
+++ /dev/null
@@ -1,129 +0,0 @@
-from __future__ import unicode_literals
-from django.contrib.contenttypes.models import ContentType
-from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class Tag(models.Model):
- """
- Tags have a descriptive slug, and are attached to an arbitrary object.
- """
- tag = models.SlugField()
- content_type = models.ForeignKey(ContentType)
- object_id = models.PositiveIntegerField()
- tagged_item = GenericForeignKey('content_type', 'object_id')
-
- def __unicode__(self):
- return self.tag
-
-
-class Bookmark(models.Model):
- """
- A URL bookmark that may have multiple tags attached.
- """
- url = models.URLField()
- tags = GenericRelation(Tag)
-
- def __unicode__(self):
- return 'Bookmark: %s' % self.url
-
-
-class Note(models.Model):
- """
- A textual note that may have multiple tags attached.
- """
- text = models.TextField()
- tags = GenericRelation(Tag)
-
- def __unicode__(self):
- return 'Note: %s' % self.text
-
-
-class TestGenericRelations(TestCase):
- def setUp(self):
- self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
- Tag.objects.create(tagged_item=self.bookmark, tag='django')
- Tag.objects.create(tagged_item=self.bookmark, tag='python')
- self.note = Note.objects.create(text='Remember the milk')
- Tag.objects.create(tagged_item=self.note, tag='reminder')
-
- def test_generic_relation(self):
- """
- Test a relationship that spans a GenericRelation field.
- IE. A reverse generic relationship.
- """
-
- class BookmarkSerializer(serializers.ModelSerializer):
- tags = serializers.RelatedField(many=True)
-
- class Meta:
- model = Bookmark
- exclude = ('id',)
-
- serializer = BookmarkSerializer(self.bookmark)
- expected = {
- 'tags': ['django', 'python'],
- 'url': 'https://www.djangoproject.com/'
- }
- self.assertEqual(serializer.data, expected)
-
- def test_generic_nested_relation(self):
- """
- Test saving a GenericRelation field via a nested serializer.
- """
-
- class TagSerializer(serializers.ModelSerializer):
- class Meta:
- model = Tag
- exclude = ('content_type', 'object_id')
-
- class BookmarkSerializer(serializers.ModelSerializer):
- tags = TagSerializer()
-
- class Meta:
- model = Bookmark
- exclude = ('id',)
-
- data = {
- 'url': 'https://docs.djangoproject.com/',
- 'tags': [
- {'tag': 'contenttypes'},
- {'tag': 'genericrelations'},
- ]
- }
- serializer = BookmarkSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.object.tags.count(), 2)
-
- def test_generic_fk(self):
- """
- Test a relationship that spans a GenericForeignKey field.
- IE. A forward generic relationship.
- """
-
- class TagSerializer(serializers.ModelSerializer):
- tagged_item = serializers.RelatedField()
-
- class Meta:
- model = Tag
- exclude = ('id', 'content_type', 'object_id')
-
- serializer = TagSerializer(Tag.objects.all(), many=True)
- expected = [
- {
- 'tag': 'django',
- 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
- },
- {
- 'tag': 'python',
- 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
- },
- {
- 'tag': 'reminder',
- 'tagged_item': 'Note: Remember the milk'
- }
- ]
- self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py
deleted file mode 100644
index 996bd5b0..00000000
--- a/rest_framework/tests/test_generics.py
+++ /dev/null
@@ -1,609 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.shortcuts import get_object_or_404
-from django.test import TestCase
-from rest_framework import generics, renderers, serializers, status
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
-from rest_framework.compat import six
-
-factory = APIRequestFactory()
-
-
-class RootView(generics.ListCreateAPIView):
- """
- Example description for OPTIONS.
- """
- model = BasicModel
-
-
-class InstanceView(generics.RetrieveUpdateDestroyAPIView):
- """
- Example description for OPTIONS.
- """
- model = BasicModel
-
- def get_queryset(self):
- queryset = super(InstanceView, self).get_queryset()
- return queryset.exclude(text='filtered out')
-
-
-class SlugSerializer(serializers.ModelSerializer):
- slug = serializers.Field() # read only
-
- class Meta:
- model = SlugBasedModel
- exclude = ('id',)
-
-
-class SlugBasedInstanceView(InstanceView):
- """
- A model with a slug-field.
- """
- model = SlugBasedModel
- serializer_class = SlugSerializer
- lookup_field = 'slug'
-
-
-class TestRootView(TestCase):
- def setUp(self):
- """
- Create 3 BasicModel instances.
- """
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = RootView.as_view()
-
- def test_get_root_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/')
- with self.assertNumQueries(1):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- def test_post_root_view(self):
- """
- POST requests to ListCreateAPIView should create a new object.
- """
- data = {'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
- created = self.objects.get(id=4)
- self.assertEqual(created.text, 'foobar')
-
- def test_put_root_view(self):
- """
- PUT requests to ListCreateAPIView should not be allowed
- """
- data = {'text': 'foobar'}
- request = factory.put('/', data, format='json')
- with self.assertNumQueries(0):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
-
- def test_delete_root_view(self):
- """
- DELETE requests to ListCreateAPIView should not be allowed
- """
- request = factory.delete('/')
- with self.assertNumQueries(0):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
-
- def test_options_root_view(self):
- """
- OPTIONS requests to ListCreateAPIView should return metadata
- """
- request = factory.options('/')
- with self.assertNumQueries(0):
- response = self.view(request).render()
- expected = {
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'renders': [
- 'application/json',
- 'text/html'
- ],
- 'name': 'Root',
- 'description': 'Example description for OPTIONS.',
- 'actions': {
- 'POST': {
- 'text': {
- 'max_length': 100,
- 'read_only': False,
- 'required': True,
- 'type': 'string',
- "label": "Text comes here",
- "help_text": "Text description."
- },
- 'id': {
- 'read_only': True,
- 'required': False,
- 'type': 'integer',
- 'label': 'ID',
- },
- }
- }
- }
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, expected)
-
- def test_post_cannot_set_id(self):
- """
- POST requests to create a new object should not be able to set the id.
- """
- data = {'id': 999, 'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
- created = self.objects.get(id=4)
- self.assertEqual(created.text, 'foobar')
-
-
-class TestInstanceView(TestCase):
- def setUp(self):
- """
- Create 3 BasicModel intances.
- """
- items = ['foo', 'bar', 'baz', 'filtered out']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects.exclude(text='filtered out')
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = InstanceView.as_view()
- self.slug_based_view = SlugBasedInstanceView.as_view()
-
- def test_get_instance_view(self):
- """
- GET requests to RetrieveUpdateDestroyAPIView should return a single object.
- """
- request = factory.get('/1')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[0])
-
- def test_post_instance_view(self):
- """
- POST requests to RetrieveUpdateDestroyAPIView should not be allowed
- """
- data = {'text': 'foobar'}
- request = factory.post('/', data, format='json')
- with self.assertNumQueries(0):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
-
- def test_put_instance_view(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView should update an object.
- """
- data = {'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(2):
- response = self.view(request, pk='1').render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
- updated = self.objects.get(id=1)
- self.assertEqual(updated.text, 'foobar')
-
- def test_patch_instance_view(self):
- """
- PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
- """
- data = {'text': 'foobar'}
- request = factory.patch('/1', data, format='json')
-
- with self.assertNumQueries(2):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
- updated = self.objects.get(id=1)
- self.assertEqual(updated.text, 'foobar')
-
- def test_delete_instance_view(self):
- """
- DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
- """
- request = factory.delete('/1')
- with self.assertNumQueries(2):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
- self.assertEqual(response.content, six.b(''))
- ids = [obj.id for obj in self.objects.all()]
- self.assertEqual(ids, [2, 3])
-
- def test_options_instance_view(self):
- """
- OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
- """
- request = factory.options('/1')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
- expected = {
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'renders': [
- 'application/json',
- 'text/html'
- ],
- 'name': 'Instance',
- 'description': 'Example description for OPTIONS.',
- 'actions': {
- 'PUT': {
- 'text': {
- 'max_length': 100,
- 'read_only': False,
- 'required': True,
- 'type': 'string',
- 'label': 'Text comes here',
- 'help_text': 'Text description.'
- },
- 'id': {
- 'read_only': True,
- 'required': False,
- 'type': 'integer',
- 'label': 'ID',
- },
- }
- }
- }
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, expected)
-
- def test_options_before_instance_create(self):
- """
- OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
- before the instance has been created
- """
- request = factory.options('/999')
- with self.assertNumQueries(1):
- response = self.view(request, pk=999).render()
- expected = {
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'renders': [
- 'application/json',
- 'text/html'
- ],
- 'name': 'Instance',
- 'description': 'Example description for OPTIONS.',
- 'actions': {
- 'PUT': {
- 'text': {
- 'max_length': 100,
- 'read_only': False,
- 'required': True,
- 'type': 'string',
- 'label': 'Text comes here',
- 'help_text': 'Text description.'
- },
- 'id': {
- 'read_only': True,
- 'required': False,
- 'type': 'integer',
- 'label': 'ID',
- },
- }
- }
- }
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, expected)
-
- def test_get_instance_view_incorrect_arg(self):
- """
- GET requests with an incorrect pk type, should raise 404, not 500.
- Regression test for #890.
- """
- request = factory.get('/a')
- with self.assertNumQueries(0):
- response = self.view(request, pk='a').render()
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- def test_put_cannot_set_id(self):
- """
- PUT requests to create a new object should not be able to set the id.
- """
- data = {'id': 999, 'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(2):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
- updated = self.objects.get(id=1)
- self.assertEqual(updated.text, 'foobar')
-
- def test_put_to_deleted_instance(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView should create an object
- if it does not currently exist.
- """
- self.objects.get(id=1).delete()
- data = {'text': 'foobar'}
- request = factory.put('/1', data, format='json')
- with self.assertNumQueries(3):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
- updated = self.objects.get(id=1)
- self.assertEqual(updated.text, 'foobar')
-
- def test_put_to_filtered_out_instance(self):
- """
- PUT requests to an URL of instance which is filtered out should not be
- able to create new objects.
- """
- data = {'text': 'foo'}
- filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
- request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
- response = self.view(request, pk=filtered_out_pk).render()
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
- def test_put_as_create_on_id_based_url(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView should create an object
- at the requested url if it doesn't exist.
- """
- data = {'text': 'foobar'}
- # pk fields can not be created on demand, only the database can set the pk for a new object
- request = factory.put('/5', data, format='json')
- with self.assertNumQueries(3):
- response = self.view(request, pk=5).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- new_obj = self.objects.get(pk=5)
- self.assertEqual(new_obj.text, 'foobar')
-
- def test_put_as_create_on_slug_based_url(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView should create an object
- at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
- """
- data = {'text': 'foobar'}
- request = factory.put('/test_slug', data, format='json')
- with self.assertNumQueries(2):
- response = self.slug_based_view(request, slug='test_slug').render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
- new_obj = SlugBasedModel.objects.get(slug='test_slug')
- self.assertEqual(new_obj.text, 'foobar')
-
- def test_patch_cannot_create_an_object(self):
- """
- PATCH requests should not be able to create objects.
- """
- data = {'text': 'foobar'}
- request = factory.patch('/999', data, format='json')
- with self.assertNumQueries(1):
- response = self.view(request, pk=999).render()
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertFalse(self.objects.filter(id=999).exists())
-
-
-class TestOverriddenGetObject(TestCase):
- """
- Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
- queryset/model mechanism but instead overrides get_object()
- """
- def setUp(self):
- """
- Create 3 BasicModel intances.
- """
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
-
- class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
- """
- Example detail view for override of get_object().
- """
- model = BasicModel
-
- def get_object(self):
- pk = int(self.kwargs['pk'])
- return get_object_or_404(BasicModel.objects.all(), id=pk)
-
- self.view = OverriddenGetObjectView.as_view()
-
- def test_overridden_get_object_view(self):
- """
- GET requests to RetrieveUpdateDestroyAPIView should return a single object.
- """
- request = factory.get('/1')
- with self.assertNumQueries(1):
- response = self.view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[0])
-
-
-# Regression test for #285
-
-class CommentSerializer(serializers.ModelSerializer):
- class Meta:
- model = Comment
- exclude = ('created',)
-
-
-class CommentView(generics.ListCreateAPIView):
- serializer_class = CommentSerializer
- model = Comment
-
-
-class TestCreateModelWithAutoNowAddField(TestCase):
- def setUp(self):
- self.objects = Comment.objects
- self.view = CommentView.as_view()
-
- def test_create_model_with_auto_now_add_field(self):
- """
- Regression test for #285
-
- https://github.com/tomchristie/django-rest-framework/issues/285
- """
- data = {'email': 'foobar@example.com', 'content': 'foobar'}
- request = factory.post('/', data, format='json')
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- created = self.objects.get(id=1)
- self.assertEqual(created.content, 'foobar')
-
-
-# Test for particularly ugly regression with m2m in browsable API
-class ClassB(models.Model):
- name = models.CharField(max_length=255)
-
-
-class ClassA(models.Model):
- name = models.CharField(max_length=255)
- childs = models.ManyToManyField(ClassB, blank=True, null=True)
-
-
-class ClassASerializer(serializers.ModelSerializer):
- childs = serializers.PrimaryKeyRelatedField(many=True, source='childs')
-
- class Meta:
- model = ClassA
-
-
-class ExampleView(generics.ListCreateAPIView):
- serializer_class = ClassASerializer
- model = ClassA
-
-
-class TestM2MBrowseableAPI(TestCase):
- def test_m2m_in_browseable_api(self):
- """
- Test for particularly ugly regression with m2m in browsable API
- """
- request = factory.get('/', HTTP_ACCEPT='text/html')
- view = ExampleView().as_view()
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-class InclusiveFilterBackend(object):
- def filter_queryset(self, request, queryset, view):
- return queryset.filter(text='foo')
-
-
-class ExclusiveFilterBackend(object):
- def filter_queryset(self, request, queryset, view):
- return queryset.filter(text='other')
-
-
-class TwoFieldModel(models.Model):
- field_a = models.CharField(max_length=100)
- field_b = models.CharField(max_length=100)
-
-
-class DynamicSerializerView(generics.ListCreateAPIView):
- model = TwoFieldModel
- renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
-
- def get_serializer_class(self):
- if self.request.method == 'POST':
- class DynamicSerializer(serializers.ModelSerializer):
- class Meta:
- model = TwoFieldModel
- fields = ('field_b',)
- return DynamicSerializer
- return super(DynamicSerializerView, self).get_serializer_class()
-
-
-class TestFilterBackendAppliedToViews(TestCase):
-
- def setUp(self):
- """
- Create 3 BasicModel instances to filter on.
- """
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
-
- def test_get_root_view_filters_by_name_with_filter_backend(self):
- """
- GET requests to ListCreateAPIView should return filtered list.
- """
- root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
- request = factory.get('/')
- response = root_view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(len(response.data), 1)
- self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
-
- def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
- """
- GET requests to ListCreateAPIView should return empty list when all models are filtered out.
- """
- root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
- request = factory.get('/')
- response = root_view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, [])
-
- def test_get_instance_view_filters_out_name_with_filter_backend(self):
- """
- GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
- """
- instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
- request = factory.get('/1')
- response = instance_view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.data, {'detail': 'Not found'})
-
- def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
- """
- GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
- """
- instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
- request = factory.get('/1')
- response = instance_view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
-
- def test_dynamic_serializer_form_in_browsable_api(self):
- """
- GET requests to ListCreateAPIView should return filtered list.
- """
- view = DynamicSerializerView.as_view()
- request = factory.get('/')
- response = view(request).render()
- self.assertContains(response, 'field_b')
- self.assertNotContains(response, 'field_a')
diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py
deleted file mode 100644
index 8957a43c..00000000
--- a/rest_framework/tests/test_htmlrenderer.py
+++ /dev/null
@@ -1,118 +0,0 @@
-from __future__ import unicode_literals
-from django.core.exceptions import PermissionDenied
-from django.http import Http404
-from django.test import TestCase
-from django.template import TemplateDoesNotExist, Template
-import django.template.loader
-from rest_framework import status
-from rest_framework.compat import patterns, url
-from rest_framework.decorators import api_view, renderer_classes
-from rest_framework.renderers import TemplateHTMLRenderer
-from rest_framework.response import Response
-from rest_framework.compat import six
-
-
-@api_view(('GET',))
-@renderer_classes((TemplateHTMLRenderer,))
-def example(request):
- """
- A view that can returns an HTML representation.
- """
- data = {'object': 'foobar'}
- return Response(data, template_name='example.html')
-
-
-@api_view(('GET',))
-@renderer_classes((TemplateHTMLRenderer,))
-def permission_denied(request):
- raise PermissionDenied()
-
-
-@api_view(('GET',))
-@renderer_classes((TemplateHTMLRenderer,))
-def not_found(request):
- raise Http404()
-
-
-urlpatterns = patterns('',
- url(r'^$', example),
- url(r'^permission_denied$', permission_denied),
- url(r'^not_found$', not_found),
-)
-
-
-class TemplateHTMLRendererTests(TestCase):
- urls = 'rest_framework.tests.test_htmlrenderer'
-
- def setUp(self):
- """
- Monkeypatch get_template
- """
- self.get_template = django.template.loader.get_template
-
- def get_template(template_name):
- if template_name == 'example.html':
- return Template("example: {{ object }}")
- raise TemplateDoesNotExist(template_name)
-
- django.template.loader.get_template = get_template
-
- def tearDown(self):
- """
- Revert monkeypatching
- """
- django.template.loader.get_template = self.get_template
-
- def test_simple_html_view(self):
- response = self.client.get('/')
- self.assertContains(response, "example: foobar")
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
-
- def test_not_found_html_view(self):
- response = self.client.get('/not_found')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.content, six.b("404 Not Found"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
-
- def test_permission_denied_html_view(self):
- response = self.client.get('/permission_denied')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(response.content, six.b("403 Forbidden"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
-
-
-class TemplateHTMLRendererExceptionTests(TestCase):
- urls = 'rest_framework.tests.test_htmlrenderer'
-
- def setUp(self):
- """
- Monkeypatch get_template
- """
- self.get_template = django.template.loader.get_template
-
- def get_template(template_name):
- if template_name == '404.html':
- return Template("404: {{ detail }}")
- if template_name == '403.html':
- return Template("403: {{ detail }}")
- raise TemplateDoesNotExist(template_name)
-
- django.template.loader.get_template = get_template
-
- def tearDown(self):
- """
- Revert monkeypatching
- """
- django.template.loader.get_template = self.get_template
-
- def test_not_found_html_view_with_template(self):
- response = self.client.get('/not_found')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.content, six.b("404: Not found"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
-
- def test_permission_denied_html_view_with_template(self):
- response = self.client.get('/permission_denied')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(response.content, six.b("403: Permission denied"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py
deleted file mode 100644
index 83d46043..00000000
--- a/rest_framework/tests/test_hyperlinkedserializers.py
+++ /dev/null
@@ -1,379 +0,0 @@
-from __future__ import unicode_literals
-import json
-from django.test import TestCase
-from rest_framework import generics, status, serializers
-from rest_framework.compat import patterns, url
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import (
- Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
- Album, Photo, OptionalRelationModel
-)
-
-factory = APIRequestFactory()
-
-
-class BlogPostCommentSerializer(serializers.ModelSerializer):
- url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
- text = serializers.CharField()
- blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
-
- class Meta:
- model = BlogPostComment
- fields = ('text', 'blog_post_url', 'url')
-
-
-class PhotoSerializer(serializers.Serializer):
- description = serializers.CharField()
- album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title')
-
- def restore_object(self, attrs, instance=None):
- return Photo(**attrs)
-
-
-class AlbumSerializer(serializers.ModelSerializer):
- url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
-
- class Meta:
- model = Album
- fields = ('title', 'url')
-
-
-class BasicList(generics.ListCreateAPIView):
- model = BasicModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
- model = BasicModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class AnchorDetail(generics.RetrieveAPIView):
- model = Anchor
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class ManyToManyList(generics.ListAPIView):
- model = ManyToManyModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class ManyToManyDetail(generics.RetrieveAPIView):
- model = ManyToManyModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-class BlogPostCommentListCreate(generics.ListCreateAPIView):
- model = BlogPostComment
- serializer_class = BlogPostCommentSerializer
-
-
-class BlogPostCommentDetail(generics.RetrieveAPIView):
- model = BlogPostComment
- serializer_class = BlogPostCommentSerializer
-
-
-class BlogPostDetail(generics.RetrieveAPIView):
- model = BlogPost
-
-
-class PhotoListCreate(generics.ListCreateAPIView):
- model = Photo
- model_serializer_class = PhotoSerializer
-
-
-class AlbumDetail(generics.RetrieveAPIView):
- model = Album
- serializer_class = AlbumSerializer
- lookup_field = 'title'
-
-
-class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
- model = OptionalRelationModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
-
-
-urlpatterns = patterns('',
- url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
- url(r'^basic/(?P\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
- url(r'^anchor/(?P\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
- url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
- url(r'^manytomany/(?P\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
- url(r'^posts/(?P\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
- url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
- url(r'^comments/(?P\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
- url(r'^albums/(?P\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
- url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
- url(r'^optionalrelation/(?P\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
-)
-
-
-class TestBasicHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create 3 BasicModel instances.
- """
- items = ['foo', 'bar', 'baz']
- for item in items:
- BasicModel(text=item).save()
- self.objects = BasicModel.objects
- self.data = [
- {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.list_view = BasicList.as_view()
- self.detail_view = BasicDetail.as_view()
-
- def test_get_list_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/basic/')
- response = self.list_view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- def test_get_detail_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/basic/1')
- response = self.detail_view(request, pk=1).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[0])
-
-
-class TestManyToManyHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create 3 BasicModel instances.
- """
- items = ['foo', 'bar', 'baz']
- anchors = []
- for item in items:
- anchor = Anchor(text=item)
- anchor.save()
- anchors.append(anchor)
-
- manytomany = ManyToManyModel()
- manytomany.save()
- manytomany.rel.add(*anchors)
-
- self.data = [{
- 'url': 'http://testserver/manytomany/1/',
- 'rel': [
- 'http://testserver/anchor/1/',
- 'http://testserver/anchor/2/',
- 'http://testserver/anchor/3/',
- ]
- }]
- self.list_view = ManyToManyList.as_view()
- self.detail_view = ManyToManyDetail.as_view()
-
- def test_get_list_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/manytomany/')
- response = self.list_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- def test_get_detail_view(self):
- """
- GET requests to ListCreateAPIView should return list of objects.
- """
- request = factory.get('/manytomany/1/')
- response = self.detail_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[0])
-
-
-class TestHyperlinkedIdentityFieldLookup(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create 3 Album instances.
- """
- titles = ['foo', 'bar', 'baz']
- for title in titles:
- album = Album(title=title)
- album.save()
- self.detail_view = AlbumDetail.as_view()
- self.data = {
- 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
- 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
- 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
- }
-
- def test_lookup_field(self):
- """
- GET requests to AlbumDetail view should return serialized Albums
- with a url field keyed by `title`.
- """
- for album in Album.objects.all():
- request = factory.get('/albums/{0}/'.format(album.title))
- response = self.detail_view(request, title=album.title)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data[album.title])
-
-
-class TestCreateWithForeignKeys(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create a blog post
- """
- self.post = BlogPost.objects.create(title="Test post")
- self.create_view = BlogPostCommentListCreate.as_view()
-
- def test_create_comment(self):
-
- data = {
- 'text': 'A test comment',
- 'blog_post_url': 'http://testserver/posts/1/'
- }
-
- request = factory.post('/comments/', data=data)
- response = self.create_view(request)
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertEqual(response['Location'], 'http://testserver/comments/1/')
- self.assertEqual(self.post.blogpostcomment_set.count(), 1)
- self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
-
-
-class TestCreateWithForeignKeysAndCustomSlug(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create an Album
- """
- self.post = Album.objects.create(title='test-album')
- self.list_create_view = PhotoListCreate.as_view()
-
- def test_create_photo(self):
-
- data = {
- 'description': 'A test photo',
- 'album_url': 'http://testserver/albums/test-album/'
- }
-
- request = factory.post('/photos/', data=data)
- response = self.list_create_view(request)
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
- self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
- self.assertEqual(self.post.photo_set.count(), 1)
- self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
-
-
-class TestOptionalRelationHyperlinkedView(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- """
- Create 1 OptionalRelationModel instances.
- """
- OptionalRelationModel().save()
- self.objects = OptionalRelationModel.objects
- self.detail_view = OptionalRelationDetail.as_view()
- self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
-
- def test_get_detail_view(self):
- """
- GET requests to RetrieveAPIView with optional relations should return None
- for non existing relations.
- """
- request = factory.get('/optionalrelationmodel-detail/1')
- response = self.detail_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data, self.data)
-
- def test_put_detail_view(self):
- """
- PUT requests to RetrieveUpdateDestroyAPIView with optional relations
- should accept None for non existing relations.
- """
- response = self.client.put('/optionalrelation/1/',
- data=json.dumps(self.data),
- content_type='application/json')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-class TestOverriddenURLField(TestCase):
- def setUp(self):
- class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer):
- url = serializers.SerializerMethodField('get_url')
-
- class Meta:
- model = BlogPost
- fields = ('title', 'url')
-
- def get_url(self, obj):
- return 'foo bar'
-
- self.Serializer = OverriddenURLSerializer
- self.obj = BlogPost.objects.create(title='New blog post')
-
- def test_overridden_url_field(self):
- """
- The 'url' field should respect overriding.
- Regression test for #936.
- """
- serializer = self.Serializer(self.obj)
- self.assertEqual(
- serializer.data,
- {'title': 'New blog post', 'url': 'foo bar'}
- )
-
-
-class TestURLFieldNameBySettings(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- self.saved_url_field_name = api_settings.URL_FIELD_NAME
- api_settings.URL_FIELD_NAME = 'global_url_field'
-
- class Serializer(serializers.HyperlinkedModelSerializer):
-
- class Meta:
- model = BlogPost
- fields = ('title', api_settings.URL_FIELD_NAME)
-
- self.Serializer = Serializer
- self.obj = BlogPost.objects.create(title="New blog post")
-
- def tearDown(self):
- api_settings.URL_FIELD_NAME = self.saved_url_field_name
-
- def test_overridden_url_field_name(self):
- request = factory.get('/posts/')
- serializer = self.Serializer(self.obj, context={'request': request})
- self.assertIn(api_settings.URL_FIELD_NAME, serializer.data)
-
-
-class TestURLFieldNameByOptions(TestCase):
- urls = 'rest_framework.tests.test_hyperlinkedserializers'
-
- def setUp(self):
- class Serializer(serializers.HyperlinkedModelSerializer):
-
- class Meta:
- model = BlogPost
- fields = ('title', 'serializer_url_field')
- url_field_name = 'serializer_url_field'
-
- self.Serializer = Serializer
- self.obj = BlogPost.objects.create(title="New blog post")
-
- def test_overridden_url_field_name(self):
- request = factory.get('/posts/')
- serializer = self.Serializer(self.obj, context={'request': request})
- self.assertIn(self.Serializer.Meta.url_field_name, serializer.data)
diff --git a/rest_framework/tests/test_multitable_inheritance.py b/rest_framework/tests/test_multitable_inheritance.py
deleted file mode 100644
index 00c15327..00000000
--- a/rest_framework/tests/test_multitable_inheritance.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.tests.models import RESTFrameworkModel
-
-
-# Models
-class ParentModel(RESTFrameworkModel):
- name1 = models.CharField(max_length=100)
-
-
-class ChildModel(ParentModel):
- name2 = models.CharField(max_length=100)
-
-
-class AssociatedModel(RESTFrameworkModel):
- ref = models.OneToOneField(ParentModel, primary_key=True)
- name = models.CharField(max_length=100)
-
-
-# Serializers
-class DerivedModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = ChildModel
-
-
-class AssociatedModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = AssociatedModel
-
-
-# Tests
-class IneritedModelSerializationTests(TestCase):
-
- def test_multitable_inherited_model_fields_as_expected(self):
- """
- Assert that the parent pointer field is not included in the fields
- serialized fields
- """
- child = ChildModel(name1='parent name', name2='child name')
- serializer = DerivedModelSerializer(child)
- self.assertEqual(set(serializer.data.keys()),
- set(['name1', 'name2', 'id']))
-
- def test_onetoone_primary_key_model_fields_as_expected(self):
- """
- Assert that a model with a onetoone field that is the primary key is
- not treated like a derived model
- """
- parent = ParentModel(name1='parent name')
- associate = AssociatedModel(name='hello', ref=parent)
- serializer = AssociatedModelSerializer(associate)
- self.assertEqual(set(serializer.data.keys()),
- set(['name', 'ref']))
-
- def test_data_is_valid_without_parent_ptr(self):
- """
- Assert that the pointer to the parent table is not a required field
- for input data
- """
- data = {
- 'name1': 'parent name',
- 'name2': 'child name',
- }
- serializer = DerivedModelSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
diff --git a/rest_framework/tests/test_negotiation.py b/rest_framework/tests/test_negotiation.py
deleted file mode 100644
index 04b89eb6..00000000
--- a/rest_framework/tests/test_negotiation.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.negotiation import DefaultContentNegotiation
-from rest_framework.request import Request
-from rest_framework.renderers import BaseRenderer
-from rest_framework.test import APIRequestFactory
-
-
-factory = APIRequestFactory()
-
-
-class MockJSONRenderer(BaseRenderer):
- media_type = 'application/json'
-
-
-class MockHTMLRenderer(BaseRenderer):
- media_type = 'text/html'
-
-
-class NoCharsetSpecifiedRenderer(BaseRenderer):
- media_type = 'my/media'
-
-
-class TestAcceptedMediaType(TestCase):
- def setUp(self):
- self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
- self.negotiator = DefaultContentNegotiation()
-
- def select_renderer(self, request):
- return self.negotiator.select_renderer(request, self.renderers)
-
- def test_client_without_accept_use_renderer(self):
- request = Request(factory.get('/'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEqual(accepted_media_type, 'application/json')
-
- def test_client_underspecifies_accept_use_renderer(self):
- request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEqual(accepted_media_type, 'application/json')
-
- def test_client_overspecifies_accept_use_client(self):
- request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
- accepted_renderer, accepted_media_type = self.select_renderer(request)
- self.assertEqual(accepted_media_type, 'application/json; indent=8')
diff --git a/rest_framework/tests/test_nullable_fields.py b/rest_framework/tests/test_nullable_fields.py
deleted file mode 100644
index 6ee55c00..00000000
--- a/rest_framework/tests/test_nullable_fields.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from django.core.urlresolvers import reverse
-
-from rest_framework.compat import patterns, url
-from rest_framework.test import APITestCase
-from rest_framework.tests.models import NullableForeignKeySource
-from rest_framework.tests.serializers import NullableFKSourceSerializer
-from rest_framework.tests.views import NullableFKSourceDetail
-
-
-urlpatterns = patterns(
- '',
- url(r'^objects/(?P\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'),
-)
-
-
-class NullableForeignKeyTests(APITestCase):
- """
- DRF should be able to handle nullable foreign keys when a test
- Client POST/PUT request is made with its own serialized object.
- """
- urls = 'rest_framework.tests.test_nullable_fields'
-
- def test_updating_object_with_null_fk(self):
- obj = NullableForeignKeySource(name='example', target=None)
- obj.save()
- serialized_data = NullableFKSourceSerializer(obj).data
-
- response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data)
-
- self.assertEqual(response.data, serialized_data)
diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py
deleted file mode 100644
index cadb515f..00000000
--- a/rest_framework/tests/test_pagination.py
+++ /dev/null
@@ -1,517 +0,0 @@
-from __future__ import unicode_literals
-import datetime
-from decimal import Decimal
-from django.db import models
-from django.core.paginator import Paginator
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import generics, status, pagination, filters, serializers
-from rest_framework.compat import django_filters
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel
-
-factory = APIRequestFactory()
-
-
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
-class RootView(generics.ListCreateAPIView):
- """
- Example description for OPTIONS.
- """
- model = BasicModel
- paginate_by = 10
-
-
-class DefaultPageSizeKwargView(generics.ListAPIView):
- """
- View for testing default paginate_by_param usage
- """
- model = BasicModel
-
-
-class PaginateByParamView(generics.ListAPIView):
- """
- View for testing custom paginate_by_param usage
- """
- model = BasicModel
- paginate_by_param = 'page_size'
-
-
-class MaxPaginateByView(generics.ListAPIView):
- """
- View for testing custom max_paginate_by usage
- """
- model = BasicModel
- paginate_by = 3
- max_paginate_by = 5
- paginate_by_param = 'page_size'
-
-
-class IntegrationTestPagination(TestCase):
- """
- Integration tests for paginated list views.
- """
-
- def setUp(self):
- """
- Create 26 BasicModel instances.
- """
- for char in 'abcdefghijklmnopqrstuvwxyz':
- BasicModel(text=char * 3).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = RootView.as_view()
-
- def test_get_paginated_root_view(self):
- """
- GET requests to paginated ListCreateAPIView should return paginated results.
- """
- request = factory.get('/')
- # Note: Database queries are a `SELECT COUNT`, and `SELECT `
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(response.data['next'])
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[10:20])
- self.assertNotEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(response.data['next'])
- with self.assertNumQueries(2):
- response = self.view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 26)
- self.assertEqual(response.data['results'], self.data[20:])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
-
-class IntegrationTestPaginationAndFiltering(TestCase):
-
- def setUp(self):
- """
- Create 50 FilterableItem instances.
- """
- base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
- for i in range(26):
- text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
- decimal = base_data[1] + i
- date = base_data[2] - datetime.timedelta(days=i * 2)
- FilterableItem(text=text, decimal=decimal, date=date).save()
-
- self.objects = FilterableItem.objects
- self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
- for obj in self.objects.all()
- ]
-
- @unittest.skipUnless(django_filters, 'django-filter not installed')
- def test_get_django_filter_paginated_filtered_root_view(self):
- """
- GET requests to paginated filtered ListCreateAPIView should return
- paginated results. The next and previous links should preserve the
- filtered parameters.
- """
- class DecimalFilter(django_filters.FilterSet):
- decimal = django_filters.NumberFilter(lookup_type='lt')
-
- class Meta:
- model = FilterableItem
- fields = ['text', 'decimal', 'date']
-
- class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_class = DecimalFilter
- filter_backends = (filters.DjangoFilterBackend,)
-
- view = FilterFieldsRootView.as_view()
-
- EXPECTED_NUM_QUERIES = 2
-
- request = factory.get('/?decimal=15.20')
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(response.data['next'])
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[10:15])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(response.data['previous'])
- with self.assertNumQueries(EXPECTED_NUM_QUERIES):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- def test_get_basic_paginated_filtered_root_view(self):
- """
- Same as `test_get_django_filter_paginated_filtered_root_view`,
- except using a custom filter backend instead of the django-filter
- backend,
- """
-
- class DecimalFilterBackend(filters.BaseFilterBackend):
- def filter_queryset(self, request, queryset, view):
- return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
-
- class BasicFilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
- paginate_by = 10
- filter_backends = (DecimalFilterBackend,)
-
- view = BasicFilterFieldsRootView.as_view()
-
- request = factory.get('/?decimal=15.20')
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
- request = factory.get(response.data['next'])
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[10:15])
- self.assertEqual(response.data['next'], None)
- self.assertNotEqual(response.data['previous'], None)
-
- request = factory.get(response.data['previous'])
- with self.assertNumQueries(2):
- response = view(request).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data['count'], 15)
- self.assertEqual(response.data['results'], self.data[:10])
- self.assertNotEqual(response.data['next'], None)
- self.assertEqual(response.data['previous'], None)
-
-
-class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
- class Meta:
- object_serializer_class = serializers.Serializer
-
-
-class UnitTestPagination(TestCase):
- """
- Unit tests for pagination of primitive objects.
- """
-
- def setUp(self):
- self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
- paginator = Paginator(self.objects, 10)
- self.first_page = paginator.page(1)
- self.last_page = paginator.page(3)
-
- def test_native_pagination(self):
- serializer = pagination.PaginationSerializer(self.first_page)
- self.assertEqual(serializer.data['count'], 26)
- self.assertEqual(serializer.data['next'], '?page=2')
- self.assertEqual(serializer.data['previous'], None)
- self.assertEqual(serializer.data['results'], self.objects[:10])
-
- serializer = pagination.PaginationSerializer(self.last_page)
- self.assertEqual(serializer.data['count'], 26)
- self.assertEqual(serializer.data['next'], None)
- self.assertEqual(serializer.data['previous'], '?page=2')
- self.assertEqual(serializer.data['results'], self.objects[20:])
-
- def test_context_available_in_result(self):
- """
- Ensure context gets passed through to the object serializer.
- """
- serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
- serializer.data
- results = serializer.fields[serializer.results_field]
- self.assertEqual(serializer.context, results.context)
-
-
-class TestUnpaginated(TestCase):
- """
- Tests for list views without pagination.
- """
-
- def setUp(self):
- """
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = DefaultPageSizeKwargView.as_view()
-
- def test_unpaginated(self):
- """
- Tests the default page size for this view.
- no page size --> no limit --> no meta data
- """
- request = factory.get('/')
- response = self.view(request)
- self.assertEqual(response.data, self.data)
-
-
-class TestCustomPaginateByParam(TestCase):
- """
- Tests for list views with default page size kwarg
- """
-
- def setUp(self):
- """
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = PaginateByParamView.as_view()
-
- def test_default_page_size(self):
- """
- Tests the default page size for this view.
- no page size --> no limit --> no meta data
- """
- request = factory.get('/')
- response = self.view(request).render()
- self.assertEqual(response.data, self.data)
-
- def test_paginate_by_param(self):
- """
- If paginate_by_param is set, the new kwarg should limit per view requests.
- """
- request = factory.get('/?page_size=5')
- response = self.view(request).render()
- self.assertEqual(response.data['count'], 13)
- self.assertEqual(response.data['results'], self.data[:5])
-
-
-class TestMaxPaginateByParam(TestCase):
- """
- Tests for list views with max_paginate_by kwarg
- """
-
- def setUp(self):
- """
- Create 13 BasicModel instances.
- """
- for i in range(13):
- BasicModel(text=i).save()
- self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
- self.view = MaxPaginateByView.as_view()
-
- def test_max_paginate_by(self):
- """
- If max_paginate_by is set, it should limit page size for the view.
- """
- request = factory.get('/?page_size=10')
- response = self.view(request).render()
- self.assertEqual(response.data['count'], 13)
- self.assertEqual(response.data['results'], self.data[:5])
-
- def test_max_paginate_by_without_page_size_param(self):
- """
- If max_paginate_by is set, but client does not specifiy page_size,
- standard `paginate_by` behavior should be used.
- """
- request = factory.get('/')
- response = self.view(request).render()
- self.assertEqual(response.data['results'], self.data[:3])
-
-
-### Tests for context in pagination serializers
-
-class CustomField(serializers.Field):
- def to_native(self, value):
- if not 'view' in self.context:
- raise RuntimeError("context isn't getting passed into custom field")
- return "value"
-
-
-class BasicModelSerializer(serializers.Serializer):
- text = CustomField()
-
- def __init__(self, *args, **kwargs):
- super(BasicModelSerializer, self).__init__(*args, **kwargs)
- if not 'view' in self.context:
- raise RuntimeError("context isn't getting passed into serializer init")
-
-
-class TestContextPassedToCustomField(TestCase):
- def setUp(self):
- BasicModel.objects.create(text='ala ma kota')
-
- def test_with_pagination(self):
- class ListView(generics.ListCreateAPIView):
- model = BasicModel
- serializer_class = BasicModelSerializer
- paginate_by = 1
-
- self.view = ListView.as_view()
- request = factory.get('/')
- response = self.view(request).render()
-
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-### Tests for custom pagination serializers
-
-class LinksSerializer(serializers.Serializer):
- next = pagination.NextPageField(source='*')
- prev = pagination.PreviousPageField(source='*')
-
-
-class CustomPaginationSerializer(pagination.BasePaginationSerializer):
- links = LinksSerializer(source='*') # Takes the page object as the source
- total_results = serializers.Field(source='paginator.count')
-
- results_field = 'objects'
-
-
-class TestCustomPaginationSerializer(TestCase):
- def setUp(self):
- objects = ['john', 'paul', 'george', 'ringo']
- paginator = Paginator(objects, 2)
- self.page = paginator.page(1)
-
- def test_custom_pagination_serializer(self):
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=self.page,
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': 'http://testserver/foobar?page=2',
- 'prev': None
- },
- 'total_results': 4,
- 'objects': ['john', 'paul']
- }
- self.assertEqual(serializer.data, expected)
-
-
-class NonIntegerPage(object):
-
- def __init__(self, paginator, object_list, prev_token, token, next_token):
- self.paginator = paginator
- self.object_list = object_list
- self.prev_token = prev_token
- self.token = token
- self.next_token = next_token
-
- def has_next(self):
- return not not self.next_token
-
- def next_page_number(self):
- return self.next_token
-
- def has_previous(self):
- return not not self.prev_token
-
- def previous_page_number(self):
- return self.prev_token
-
-
-class NonIntegerPaginator(object):
-
- def __init__(self, object_list, per_page):
- self.object_list = object_list
- self.per_page = per_page
-
- def count(self):
- # pretend like we don't know how many pages we have
- return None
-
- def page(self, token=None):
- if token:
- try:
- first = self.object_list.index(token)
- except ValueError:
- first = 0
- else:
- first = 0
- n = len(self.object_list)
- last = min(first + self.per_page, n)
- prev_token = self.object_list[last - (2 * self.per_page)] if first else None
- next_token = self.object_list[last] if last < n else None
- return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
-
-
-class TestNonIntegerPagination(TestCase):
-
-
- def test_custom_pagination_serializer(self):
- objects = ['john', 'paul', 'george', 'ringo']
- paginator = NonIntegerPaginator(objects, 2)
-
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=paginator.page(),
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
- 'prev': None
- },
- 'total_results': None,
- 'objects': objects[:2]
- }
- self.assertEqual(serializer.data, expected)
-
- request = APIRequestFactory().get('/foobar')
- serializer = CustomPaginationSerializer(
- instance=paginator.page('george'),
- context={'request': request}
- )
- expected = {
- 'links': {
- 'next': None,
- 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
- },
- 'total_results': None,
- 'objects': objects[2:]
- }
- self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_parsers.py b/rest_framework/tests/test_parsers.py
deleted file mode 100644
index 7699e10c..00000000
--- a/rest_framework/tests/test_parsers.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from __future__ import unicode_literals
-from rest_framework.compat import StringIO
-from django import forms
-from django.core.files.uploadhandler import MemoryFileUploadHandler
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework.compat import etree
-from rest_framework.parsers import FormParser, FileUploadParser
-from rest_framework.parsers import XMLParser
-import datetime
-
-
-class Form(forms.Form):
- field1 = forms.CharField(max_length=3)
- field2 = forms.CharField()
-
-
-class TestFormParser(TestCase):
- def setUp(self):
- self.string = "field1=abc&field2=defghijk"
-
- def test_parse(self):
- """ Make sure the `QueryDict` works OK """
- parser = FormParser()
-
- stream = StringIO(self.string)
- data = parser.parse(stream)
-
- self.assertEqual(Form(data).is_valid(), True)
-
-
-class TestXMLParser(TestCase):
- def setUp(self):
- self._input = StringIO(
- ''
- ''
- '121.0'
- 'dasd'
- ''
- '2011-12-25 12:45:00'
- ''
- )
- self._data = {
- 'field_a': 121,
- 'field_b': 'dasd',
- 'field_c': None,
- 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
- }
- self._complex_data_input = StringIO(
- ''
- ''
- '2011-12-25 12:45:00'
- ''
- '1first'
- '2second'
- ''
- 'name'
- ''
- )
- self._complex_data = {
- "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
- "name": "name",
- "sub_data_list": [
- {
- "sub_id": 1,
- "sub_name": "first"
- },
- {
- "sub_id": 2,
- "sub_name": "second"
- }
- ]
- }
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_parse(self):
- parser = XMLParser()
- data = parser.parse(self._input)
- self.assertEqual(data, self._data)
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_complex_data_parse(self):
- parser = XMLParser()
- data = parser.parse(self._complex_data_input)
- self.assertEqual(data, self._complex_data)
-
-
-class TestFileUploadParser(TestCase):
- def setUp(self):
- class MockRequest(object):
- pass
- from io import BytesIO
- self.stream = BytesIO(
- "Test text file".encode('utf-8')
- )
- request = MockRequest()
- request.upload_handlers = (MemoryFileUploadHandler(),)
- request.META = {
- 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
- 'HTTP_CONTENT_LENGTH': 14,
- }
- self.parser_context = {'request': request, 'kwargs': {}}
-
- def test_parse(self):
- """ Make sure the `QueryDict` works OK """
- parser = FileUploadParser()
- self.stream.seek(0)
- data_and_files = parser.parse(self.stream, None, self.parser_context)
- file_obj = data_and_files.files['file']
- self.assertEqual(file_obj._size, 14)
-
- def test_get_filename(self):
- parser = FileUploadParser()
- filename = parser.get_filename(self.stream, None, self.parser_context)
- self.assertEqual(filename, 'file.txt'.encode('utf-8'))
diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py
deleted file mode 100644
index 6e3a6303..00000000
--- a/rest_framework/tests/test_permissions.py
+++ /dev/null
@@ -1,300 +0,0 @@
-from __future__ import unicode_literals
-from django.contrib.auth.models import User, Permission, Group
-from django.db import models
-from django.test import TestCase
-from django.utils import unittest
-from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
-from rest_framework.compat import guardian, get_model_name
-from rest_framework.filters import DjangoObjectPermissionsFilter
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import BasicModel
-import base64
-
-factory = APIRequestFactory()
-
-class RootView(generics.ListCreateAPIView):
- model = BasicModel
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [permissions.DjangoModelPermissions]
-
-
-class InstanceView(generics.RetrieveUpdateDestroyAPIView):
- model = BasicModel
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [permissions.DjangoModelPermissions]
-
-root_view = RootView.as_view()
-instance_view = InstanceView.as_view()
-
-
-def basic_auth_header(username, password):
- credentials = ('%s:%s' % (username, password))
- base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
- return 'Basic %s' % base64_credentials
-
-
-class ModelPermissionsIntegrationTests(TestCase):
- def setUp(self):
- User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
- user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
- user.user_permissions = [
- Permission.objects.get(codename='add_basicmodel'),
- Permission.objects.get(codename='change_basicmodel'),
- Permission.objects.get(codename='delete_basicmodel')
- ]
- user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
- user.user_permissions = [
- Permission.objects.get(codename='change_basicmodel'),
- ]
-
- self.permitted_credentials = basic_auth_header('permitted', 'password')
- self.disallowed_credentials = basic_auth_header('disallowed', 'password')
- self.updateonly_credentials = basic_auth_header('updateonly', 'password')
-
- BasicModel(text='foo').save()
-
- def test_has_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = root_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_201_CREATED)
-
- def test_has_put_permissions(self):
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_has_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
-
- def test_does_not_have_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = root_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_does_not_have_put_permissions(self):
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_does_not_have_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk=1)
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_has_put_as_create_permissions(self):
- # User only has update permissions - should be able to update an entity.
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.updateonly_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- # But if PUTing to a new entity, permission should be denied.
- request = factory.put('/2', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.updateonly_credentials)
- response = instance_view(request, pk='2')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- def test_options_permitted(self):
- request = factory.options('/',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions'].keys()), ['POST'])
-
- request = factory.options('/1',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
-
- def test_options_disallowed(self):
- request = factory.options('/',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
-
- request = factory.options('/1',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
-
- def test_options_updateonly(self):
- request = factory.options('/',
- HTTP_AUTHORIZATION=self.updateonly_credentials)
- response = root_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
-
- request = factory.options('/1',
- HTTP_AUTHORIZATION=self.updateonly_credentials)
- response = instance_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
-
-
-class BasicPermModel(models.Model):
- text = models.CharField(max_length=100)
-
- class Meta:
- app_label = 'tests'
- permissions = (
- ('view_basicpermmodel', 'Can view basic perm model'),
- # add, change, delete built in to django
- )
-
-# Custom object-level permission, that includes 'view' permissions
-class ViewObjectPermissions(permissions.DjangoObjectPermissions):
- perms_map = {
- 'GET': ['%(app_label)s.view_%(model_name)s'],
- 'OPTIONS': ['%(app_label)s.view_%(model_name)s'],
- 'HEAD': ['%(app_label)s.view_%(model_name)s'],
- 'POST': ['%(app_label)s.add_%(model_name)s'],
- 'PUT': ['%(app_label)s.change_%(model_name)s'],
- 'PATCH': ['%(app_label)s.change_%(model_name)s'],
- 'DELETE': ['%(app_label)s.delete_%(model_name)s'],
- }
-
-
-class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
- model = BasicPermModel
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [ViewObjectPermissions]
-
-object_permissions_view = ObjectPermissionInstanceView.as_view()
-
-
-class ObjectPermissionListView(generics.ListAPIView):
- model = BasicPermModel
- authentication_classes = [authentication.BasicAuthentication]
- permission_classes = [ViewObjectPermissions]
-
-object_permissions_list_view = ObjectPermissionListView.as_view()
-
-
-@unittest.skipUnless(guardian, 'django-guardian not installed')
-class ObjectPermissionsIntegrationTests(TestCase):
- """
- Integration tests for the object level permissions API.
- """
- @classmethod
- def setUpClass(cls):
- from guardian.shortcuts import assign_perm
-
- # create users
- create = User.objects.create_user
- users = {
- 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
- 'readonly': create('readonly', 'readonly@example.com', 'password'),
- 'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
- 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
- }
-
- # give everyone model level permissions, as we are not testing those
- everyone = Group.objects.create(name='everyone')
- model_name = get_model_name(BasicPermModel)
- app_label = BasicPermModel._meta.app_label
- f = '{0}_{1}'.format
- perms = {
- 'view': f('view', model_name),
- 'change': f('change', model_name),
- 'delete': f('delete', model_name)
- }
- for perm in perms.values():
- perm = '{0}.{1}'.format(app_label, perm)
- assign_perm(perm, everyone)
- everyone.user_set.add(*users.values())
-
- cls.perms = perms
- cls.users = users
-
- def setUp(self):
- from guardian.shortcuts import assign_perm
- perms = self.perms
- users = self.users
-
- # appropriate object level permissions
- readers = Group.objects.create(name='readers')
- writers = Group.objects.create(name='writers')
- deleters = Group.objects.create(name='deleters')
-
- model = BasicPermModel.objects.create(text='foo')
-
- assign_perm(perms['view'], readers, model)
- assign_perm(perms['change'], writers, model)
- assign_perm(perms['delete'], deleters, model)
-
- readers.user_set.add(users['fullaccess'], users['readonly'])
- writers.user_set.add(users['fullaccess'], users['writeonly'])
- deleters.user_set.add(users['fullaccess'], users['deleteonly'])
-
- self.credentials = {}
- for user in users.values():
- self.credentials[user.username] = basic_auth_header(user.username, 'password')
-
- # Delete
- def test_can_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
-
- def test_cannot_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
- # Update
- def test_can_update_permissions(self):
- request = factory.patch('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['writeonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data.get('text'), 'foobar')
-
- def test_cannot_update_permissions(self):
- request = factory.patch('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['deleteonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- def test_cannot_update_permissions_non_existing(self):
- request = factory.patch('/999', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['deleteonly'])
- response = object_permissions_view(request, pk='999')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- # Read
- def test_can_read_permissions(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
- def test_cannot_read_permissions(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
- response = object_permissions_view(request, pk='1')
- self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
- # Read list
- def test_can_read_list_permissions(self):
- request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
- object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
- response = object_permissions_list_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data[0].get('id'), 1)
-
- def test_cannot_read_list_permissions(self):
- request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
- object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
- response = object_permissions_list_view(request)
- self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertListEqual(response.data, [])
diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py
deleted file mode 100644
index f52e0e1e..00000000
--- a/rest_framework/tests/test_relations.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""
-General tests for relational fields.
-"""
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.tests.models import BlogPost
-
-
-class NullModel(models.Model):
- pass
-
-
-class FieldTests(TestCase):
- def test_pk_related_field_with_empty_string(self):
- """
- Regression test for #446
-
- https://github.com/tomchristie/django-rest-framework/issues/446
- """
- field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
- def test_hyperlinked_related_field_with_empty_string(self):
- field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
- def test_slug_related_field_with_empty_string(self):
- field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
- self.assertRaises(serializers.ValidationError, field.from_native, '')
- self.assertRaises(serializers.ValidationError, field.from_native, [])
-
-
-class TestManyRelatedMixin(TestCase):
- def test_missing_many_to_many_related_field(self):
- '''
- Regression test for #632
-
- https://github.com/tomchristie/django-rest-framework/pull/632
- '''
- field = serializers.RelatedField(many=True, read_only=False)
-
- into = {}
- field.field_from_native({}, None, 'field_name', into)
- self.assertEqual(into['field_name'], [])
-
-
-# Regression tests for #694 (`source` attribute on related fields)
-
-class RelatedFieldSourceTests(TestCase):
- def test_related_manager_source(self):
- """
- Relational fields should be able to use manager-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.RelatedField(many=True, source='get_blogposts_manager')
-
- class ClassWithManagerMethod(object):
- def get_blogposts_manager(self):
- return BlogPost.objects
-
- obj = ClassWithManagerMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['BlogPost object'])
-
- def test_related_queryset_source(self):
- """
- Relational fields should be able to use queryset-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
-
- class ClassWithQuerysetMethod(object):
- def get_blogposts_queryset(self):
- return BlogPost.objects.all()
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['BlogPost object'])
-
- def test_dotted_source(self):
- """
- Source argument should support dotted.source notation.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.RelatedField(many=True, source='a.b.c')
-
- class ClassWithQuerysetMethod(object):
- a = {
- 'b': {
- 'c': BlogPost.objects.all()
- }
- }
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['BlogPost object'])
-
- # Regression for #1129
- def test_exception_for_incorect_fk(self):
- """
- Check that the exception message are correct if the source field
- doesn't exist.
- """
- from rest_framework.tests.models import ManyToManySource
- class Meta:
- model = ManyToManySource
- attrs = {
- 'name': serializers.SlugRelatedField(
- slug_field='name', source='banzai'),
- 'Meta': Meta,
- }
-
- TestSerializer = type(str('TestSerializer'),
- (serializers.ModelSerializer,), attrs)
- with self.assertRaises(AttributeError):
- TestSerializer(data={'name': 'foo'})
diff --git a/rest_framework/tests/test_relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py
deleted file mode 100644
index 3c4d39af..00000000
--- a/rest_framework/tests/test_relations_hyperlink.py
+++ /dev/null
@@ -1,524 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.compat import patterns, url
-from rest_framework.test import APIRequestFactory
-from rest_framework.tests.models import (
- BlogPost,
- ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
- NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
-)
-
-factory = APIRequestFactory()
-request = factory.get('/') # Just to ensure we have a request in the serializer context
-
-
-def dummy_view(request, pk):
- pass
-
-urlpatterns = patterns('',
- url(r'^dummyurl/(?P[0-9]+)/$', dummy_view, name='dummy-url'),
- url(r'^manytomanysource/(?P[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
- url(r'^manytomanytarget/(?P[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
- url(r'^foreignkeysource/(?P[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
- url(r'^foreignkeytarget/(?P[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
- url(r'^nullableforeignkeysource/(?P[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
- url(r'^onetoonetarget/(?P[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
- url(r'^nullableonetoonesource/(?P[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
-)
-
-
-# ManyToMany
-class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = ManyToManyTarget
- fields = ('url', 'name', 'sources')
-
-
-class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = ManyToManySource
- fields = ('url', 'name', 'targets')
-
-
-# ForeignKey
-class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = ForeignKeyTarget
- fields = ('url', 'name', 'sources')
-
-
-class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = ForeignKeySource
- fields = ('url', 'name', 'target')
-
-
-# Nullable ForeignKey
-class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = NullableForeignKeySource
- fields = ('url', 'name', 'target')
-
-
-# Nullable OneToOne
-class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('url', 'name', 'nullable_source')
-
-
-# TODO: Add test that .data cannot be accessed prior to .is_valid
-
-class HyperlinkedManyToManyTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def setUp(self):
- for idx in range(1, 4):
- target = ManyToManyTarget(name='target-%d' % idx)
- target.save()
- source = ManyToManySource(name='source-%d' % idx)
- source.save()
- for target in ManyToManyTarget.objects.all():
- source.targets.add(target)
-
- def test_many_to_many_retrieve(self):
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_retrieve(self):
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_many_to_many_update(self):
- data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_update(self):
- data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
- instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 1 is updated, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
-
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_many_to_many_create(self):
- data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
- serializer = ManyToManySourceSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
- {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_create(self):
- data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
- serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-4')
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
- ]
- self.assertEqual(serializer.data, expected)
-
-
-class HyperlinkedForeignKeyTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_incorrect_type(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']})
-
- def test_reverse_foreign_key_update(self):
- data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(new_serializer.data, expected)
-
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create(self):
- data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
- serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_create(self):
- data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
- serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-3')
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_invalid_null(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['This field is required.']})
-
-
-class HyperlinkedNullableForeignKeyTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_null(self):
- data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
- expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, expected_data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_null(self):
- data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
- expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, expected_data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- # reverse foreign keys MUST be read_only
- # In the general case they do not provide .remove() or .clear()
- # and cannot be arbitrarily set.
-
- # def test_reverse_foreign_key_update(self):
- # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
- # instance = ForeignKeyTarget.objects.get(pk=1)
- # serializer = ForeignKeyTargetSerializer(instance, data=data)
- # self.assertTrue(serializer.is_valid())
- # self.assertEqual(serializer.data, data)
- # serializer.save()
-
- # # Ensure target 1 is updated, and everything else is as expected
- # queryset = ForeignKeyTarget.objects.all()
- # serializer = ForeignKeyTargetSerializer(queryset, many=True)
- # expected = [
- # {'id': 1, 'name': 'target-1', 'sources': [1]},
- # {'id': 2, 'name': 'target-2', 'sources': []},
- # ]
- # self.assertEqual(serializer.data, expected)
-
-
-class HyperlinkedNullableOneToOneTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def setUp(self):
- target = OneToOneTarget(name='target-1')
- target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
- source.save()
-
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
- expected = [
- {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
- {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
- ]
- self.assertEqual(serializer.data, expected)
-
-
-# Regression tests for #694 (`source` attribute on related fields)
-
-class HyperlinkedRelatedFieldSourceTests(TestCase):
- urls = 'rest_framework.tests.test_relations_hyperlink'
-
- def test_related_manager_source(self):
- """
- Relational fields should be able to use manager-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.HyperlinkedRelatedField(
- many=True,
- source='get_blogposts_manager',
- view_name='dummy-url',
- )
- field.context = {'request': request}
-
- class ClassWithManagerMethod(object):
- def get_blogposts_manager(self):
- return BlogPost.objects
-
- obj = ClassWithManagerMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['http://testserver/dummyurl/1/'])
-
- def test_related_queryset_source(self):
- """
- Relational fields should be able to use queryset-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.HyperlinkedRelatedField(
- many=True,
- source='get_blogposts_queryset',
- view_name='dummy-url',
- )
- field.context = {'request': request}
-
- class ClassWithQuerysetMethod(object):
- def get_blogposts_queryset(self):
- return BlogPost.objects.all()
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['http://testserver/dummyurl/1/'])
-
- def test_dotted_source(self):
- """
- Source argument should support dotted.source notation.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.HyperlinkedRelatedField(
- many=True,
- source='a.b.c',
- view_name='dummy-url',
- )
- field.context = {'request': request}
-
- class ClassWithQuerysetMethod(object):
- a = {
- 'b': {
- 'c': BlogPost.objects.all()
- }
- }
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, ['http://testserver/dummyurl/1/'])
diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py
deleted file mode 100644
index d393b0c3..00000000
--- a/rest_framework/tests/test_relations_nested.py
+++ /dev/null
@@ -1,328 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class OneToOneTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class OneToOneSource(models.Model):
- name = models.CharField(max_length=100)
- target = models.OneToOneField(OneToOneTarget, related_name='source',
- null=True, blank=True)
-
-
-class OneToManyTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class OneToManySource(models.Model):
- name = models.CharField(max_length=100)
- target = models.ForeignKey(OneToManyTarget, related_name='sources')
-
-
-class ReverseNestedOneToOneTests(TestCase):
- def setUp(self):
- class OneToOneSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneSource
- fields = ('id', 'name')
-
- class OneToOneTargetSerializer(serializers.ModelSerializer):
- source = OneToOneSourceSerializer()
-
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name', 'source')
-
- self.Serializer = OneToOneTargetSerializer
-
- for idx in range(1, 4):
- target = OneToOneTarget(name='target-%d' % idx)
- target.save()
- source = OneToOneSource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_one_to_one_retrieve(self):
- queryset = OneToOneTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
- {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
- {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_create(self):
- data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
- serializer = self.Serializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-4')
-
- # Ensure (target 4, target_source 4, source 4) are added, and
- # everything else is as expected.
- queryset = OneToOneTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
- {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
- {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
- {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_create_with_invalid_data(self):
- data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
- serializer = self.Serializer(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
-
- def test_one_to_one_update(self):
- data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
- instance = OneToOneTarget.objects.get(pk=3)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-3-updated')
-
- # Ensure (target 3, target_source 3, source 3) are updated,
- # and everything else is as expected.
- queryset = OneToOneTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
- {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
- {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
- ]
- self.assertEqual(serializer.data, expected)
-
-
-class ForwardNestedOneToOneTests(TestCase):
- def setUp(self):
- class OneToOneTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name')
-
- class OneToOneSourceSerializer(serializers.ModelSerializer):
- target = OneToOneTargetSerializer()
-
- class Meta:
- model = OneToOneSource
- fields = ('id', 'name', 'target')
-
- self.Serializer = OneToOneSourceSerializer
-
- for idx in range(1, 4):
- target = OneToOneTarget(name='target-%d' % idx)
- target.save()
- source = OneToOneSource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_one_to_one_retrieve(self):
- queryset = OneToOneSource.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
- {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
- serializer = self.Serializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure (target 4, target_source 4, source 4) are added, and
- # everything else is as expected.
- queryset = OneToOneSource.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
- {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
- {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_create_with_invalid_data(self):
- data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
- serializer = self.Serializer(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
-
- def test_one_to_one_update(self):
- data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
- instance = OneToOneSource.objects.get(pk=3)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-3-updated')
-
- # Ensure (target 3, target_source 3, source 3) are updated,
- # and everything else is as expected.
- queryset = OneToOneSource.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
- {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_one_update_to_null(self):
- data = {'id': 3, 'name': 'source-3-updated', 'target': None}
- instance = OneToOneSource.objects.get(pk=3)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
-
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-3-updated')
- self.assertEqual(obj.target, None)
-
- queryset = OneToOneSource.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
- {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
- {'id': 3, 'name': 'source-3-updated', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- # TODO: Nullable 1-1 tests
- # def test_one_to_one_delete(self):
- # data = {'id': 3, 'name': 'target-3', 'target_source': None}
- # instance = OneToOneTarget.objects.get(pk=3)
- # serializer = self.Serializer(instance, data=data)
- # self.assertTrue(serializer.is_valid())
- # serializer.save()
-
- # # Ensure (target_source 3, source 3) are deleted,
- # # and everything else is as expected.
- # queryset = OneToOneTarget.objects.all()
- # serializer = self.Serializer(queryset)
- # expected = [
- # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
- # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
- # {'id': 3, 'name': 'target-3', 'source': None}
- # ]
- # self.assertEqual(serializer.data, expected)
-
-
-class ReverseNestedOneToManyTests(TestCase):
- def setUp(self):
- class OneToManySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToManySource
- fields = ('id', 'name')
-
- class OneToManyTargetSerializer(serializers.ModelSerializer):
- sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
-
- class Meta:
- model = OneToManyTarget
- fields = ('id', 'name', 'sources')
-
- self.Serializer = OneToManyTargetSerializer
-
- target = OneToManyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- source = OneToManySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_one_to_many_retrieve(self):
- queryset = OneToManyTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'}]},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_many_create(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'},
- {'id': 4, 'name': 'source-4'}]}
- instance = OneToManyTarget.objects.get(pk=1)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-1')
-
- # Ensure source 4 is added, and everything else is as
- # expected.
- queryset = OneToManyTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'},
- {'id': 4, 'name': 'source-4'}]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_many_create_with_invalid_data(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'},
- {'id': 4}]}
- serializer = self.Serializer(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
-
- def test_one_to_many_update(self):
- data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'}]}
- instance = OneToManyTarget.objects.get(pk=1)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-1-updated')
-
- # Ensure (target 1, source 1) are updated,
- # and everything else is as expected.
- queryset = OneToManyTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
- {'id': 2, 'name': 'source-2'},
- {'id': 3, 'name': 'source-3'}]}
-
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_one_to_many_delete(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 3, 'name': 'source-3'}]}
- instance = OneToManyTarget.objects.get(pk=1)
- serializer = self.Serializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- serializer.save()
-
- # Ensure source 2 is deleted, and everything else is as
- # expected.
- queryset = OneToManyTarget.objects.all()
- serializer = self.Serializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
- {'id': 3, 'name': 'source-3'}]}
-
- ]
- self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_relations_pk.py b/rest_framework/tests/test_relations_pk.py
deleted file mode 100644
index 3815afdd..00000000
--- a/rest_framework/tests/test_relations_pk.py
+++ /dev/null
@@ -1,551 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.tests.models import (
- BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
- NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
-)
-from rest_framework.compat import six
-
-
-# ManyToMany
-class ManyToManyTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyTarget
- fields = ('id', 'name', 'sources')
-
-
-class ManyToManySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManySource
- fields = ('id', 'name', 'targets')
-
-
-# ForeignKey
-class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeyTarget
- fields = ('id', 'name', 'sources')
-
-
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ForeignKeySource
- fields = ('id', 'name', 'target')
-
-
-# Nullable ForeignKey
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
- fields = ('id', 'name', 'target')
-
-
-# Nullable OneToOne
-class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = OneToOneTarget
- fields = ('id', 'name', 'nullable_source')
-
-
-# TODO: Add test that .data cannot be accessed prior to .is_valid
-
-class PKManyToManyTests(TestCase):
- def setUp(self):
- for idx in range(1, 4):
- target = ManyToManyTarget(name='target-%d' % idx)
- target.save()
- source = ManyToManySource(name='source-%d' % idx)
- source.save()
- for target in ManyToManyTarget.objects.all():
- source.targets.add(target)
-
- def test_many_to_many_retrieve(self):
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_retrieve(self):
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_many_to_many_update(self):
- data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
- instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_update(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [1]}
- instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 1 is updated, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_many_to_many_create(self):
- data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
- serializer = ManyToManySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True)
- self.assertFalse(serializer.fields['targets'].read_only)
- expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
- {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_many_to_many_create(self):
- data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
- serializer = ManyToManyTargetSerializer(data=data)
- self.assertFalse(serializer.fields['sources'].read_only)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-4')
-
- # Ensure target 4 is added, and everything else is as expected
- queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]},
- {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
- ]
- self.assertEqual(serializer.data, expected)
-
-
-class PKForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update(self):
- data = {'id': 1, 'name': 'source-1', 'target': 2}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 2},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_incorrect_type(self):
- data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
-
- def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(new_serializer.data, expected)
-
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [2]},
- {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': 2}
- serializer = ForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1},
- {'id': 4, 'name': 'source-4', 'target': 2},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
- serializer = ForeignKeyTargetSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-3')
-
- # Ensure target 3 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': [2]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['This field is required.']})
-
- def test_foreign_key_with_empty(self):
- """
- Regression test for #1072
-
- https://github.com/tomchristie/django-rest-framework/issues/1072
- """
- serializer = NullableForeignKeySourceSerializer()
- self.assertEqual(serializer.data['target'], None)
-
-
-class PKNullableForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'id': 4, 'name': 'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, expected_data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'id': 1, 'name': 'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, expected_data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- # reverse foreign keys MUST be read_only
- # In the general case they do not provide .remove() or .clear()
- # and cannot be arbitrarily set.
-
- # def test_reverse_foreign_key_update(self):
- # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
- # instance = ForeignKeyTarget.objects.get(pk=1)
- # serializer = ForeignKeyTargetSerializer(instance, data=data)
- # self.assertTrue(serializer.is_valid())
- # self.assertEqual(serializer.data, data)
- # serializer.save()
-
- # # Ensure target 1 is updated, and everything else is as expected
- # queryset = ForeignKeyTarget.objects.all()
- # serializer = ForeignKeyTargetSerializer(queryset, many=True)
- # expected = [
- # {'id': 1, 'name': 'target-1', 'sources': [1]},
- # {'id': 2, 'name': 'target-2', 'sources': []},
- # ]
- # self.assertEqual(serializer.data, expected)
-
-
-class PKNullableOneToOneTests(TestCase):
- def setUp(self):
- target = OneToOneTarget(name='target-1')
- target.save()
- new_target = OneToOneTarget(name='target-2')
- new_target.save()
- source = NullableOneToOneSource(name='source-1', target=new_target)
- source.save()
-
- def test_reverse_foreign_key_retrieve_with_null(self):
- queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': None},
- {'id': 2, 'name': 'target-2', 'nullable_source': 1},
- ]
- self.assertEqual(serializer.data, expected)
-
-
-# The below models and tests ensure that serializer fields corresponding
-# to a ManyToManyField field with a user-specified ``through`` model are
-# set to read only
-
-
-class ManyToManyThroughTarget(models.Model):
- name = models.CharField(max_length=100)
-
-
-class ManyToManyThrough(models.Model):
- source = models.ForeignKey('ManyToManyThroughSource')
- target = models.ForeignKey(ManyToManyThroughTarget)
-
-
-class ManyToManyThroughSource(models.Model):
- name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyThroughTarget,
- related_name='sources',
- through='ManyToManyThrough')
-
-
-class ManyToManyThroughTargetSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyThroughTarget
- fields = ('id', 'name', 'sources')
-
-
-class ManyToManyThroughSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyThroughSource
- fields = ('id', 'name', 'targets')
-
-
-class PKManyToManyThroughTests(TestCase):
- def setUp(self):
- self.source = ManyToManyThroughSource.objects.create(
- name='through-source-1')
- self.target = ManyToManyThroughTarget.objects.create(
- name='through-target-1')
-
- def test_many_to_many_create(self):
- data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]}
- serializer = ManyToManyThroughSourceSerializer(data=data)
- self.assertTrue(serializer.fields['targets'].read_only)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(obj.name, 'source-2')
- self.assertEqual(obj.targets.count(), 0)
-
- def test_many_to_many_reverse_create(self):
- data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]}
- serializer = ManyToManyThroughTargetSerializer(data=data)
- self.assertTrue(serializer.fields['sources'].read_only)
- self.assertTrue(serializer.is_valid())
- serializer.save()
- obj = serializer.save()
- self.assertEqual(obj.name, 'target-2')
- self.assertEqual(obj.sources.count(), 0)
-
-
-# Regression tests for #694 (`source` attribute on related fields)
-
-
-class PrimaryKeyRelatedFieldSourceTests(TestCase):
- def test_related_manager_source(self):
- """
- Relational fields should be able to use manager-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
-
- class ClassWithManagerMethod(object):
- def get_blogposts_manager(self):
- return BlogPost.objects
-
- obj = ClassWithManagerMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, [1])
-
- def test_related_queryset_source(self):
- """
- Relational fields should be able to use queryset-returning methods as their source.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
-
- class ClassWithQuerysetMethod(object):
- def get_blogposts_queryset(self):
- return BlogPost.objects.all()
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, [1])
-
- def test_dotted_source(self):
- """
- Source argument should support dotted.source notation.
- """
- BlogPost.objects.create(title='blah')
- field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
-
- class ClassWithQuerysetMethod(object):
- a = {
- 'b': {
- 'c': BlogPost.objects.all()
- }
- }
-
- obj = ClassWithQuerysetMethod()
- value = field.field_to_native(obj, 'field_name')
- self.assertEqual(value, [1])
diff --git a/rest_framework/tests/test_relations_slug.py b/rest_framework/tests/test_relations_slug.py
deleted file mode 100644
index 435c821c..00000000
--- a/rest_framework/tests/test_relations_slug.py
+++ /dev/null
@@ -1,257 +0,0 @@
-from django.test import TestCase
-from rest_framework import serializers
-from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
-
-
-class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.SlugRelatedField(many=True, slug_field='name')
-
- class Meta:
- model = ForeignKeyTarget
-
-
-class ForeignKeySourceSerializer(serializers.ModelSerializer):
- target = serializers.SlugRelatedField(slug_field='name')
-
- class Meta:
- model = ForeignKeySource
-
-
-class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
- target = serializers.SlugRelatedField(slug_field='name', required=False)
-
- class Meta:
- model = NullableForeignKeySource
-
-
-# TODO: M2M Tests, FKTests (Non-nullable), One2One
-class SlugForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- new_target = ForeignKeyTarget(name='target-2')
- new_target.save()
- for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve(self):
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_retrieve(self):
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update(self):
- data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-2'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_incorrect_type(self):
- data = {'id': 1, 'name': 'source-1', 'target': 123}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
-
- def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
- instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- # We shouldn't have saved anything to the db yet since save
- # hasn't been called.
- queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- ]
- self.assertEqual(new_serializer.data, expected)
-
- serializer.save()
- self.assertEqual(serializer.data, data)
-
- # Ensure target 2 is update, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
- {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
- serializer = ForeignKeySourceSerializer(data=data)
- serializer.is_valid()
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is added, and everything else is as expected
- queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'},
- {'id': 4, 'name': 'source-4', 'target': 'target-2'},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
- serializer = ForeignKeyTargetSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'target-3')
-
- # Ensure target 3 is added, and everything else is as expected
- queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'target': ['This field is required.']})
-
-
-class SlugNullableForeignKeyTests(TestCase):
- def setUp(self):
- target = ForeignKeyTarget(name='target-1')
- target.save()
- for idx in range(1, 4):
- if idx == 3:
- target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
- source.save()
-
- def test_foreign_key_retrieve_with_null(self):
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_create_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'id': 4, 'name': 'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- obj = serializer.save()
- self.assertEqual(serializer.data, expected_data)
- self.assertEqual(obj.name, 'source-4')
-
- # Ensure source 4 is created, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
-
- def test_foreign_key_update_with_valid_emptystring(self):
- """
- The emptystring should be interpreted as null in the context
- of relationships.
- """
- data = {'id': 1, 'name': 'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': 'source-1', 'target': None}
- instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, expected_data)
- serializer.save()
-
- # Ensure source 1 is updated, and everything else is as expected
- queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True)
- expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None}
- ]
- self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py
deleted file mode 100644
index 0f3432c9..00000000
--- a/rest_framework/tests/test_renderers.py
+++ /dev/null
@@ -1,651 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import unicode_literals
-
-from decimal import Decimal
-from django.core.cache import cache
-from django.db import models
-from django.test import TestCase
-from django.utils import unittest
-from django.utils.translation import ugettext_lazy as _
-from rest_framework import status, permissions
-from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO
-from rest_framework.response import Response
-from rest_framework.views import APIView
-from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
- XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
-from rest_framework.parsers import YAMLParser, XMLParser
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory
-from collections import MutableMapping
-import datetime
-import json
-import pickle
-import re
-
-
-DUMMYSTATUS = status.HTTP_200_OK
-DUMMYCONTENT = 'dummycontent'
-
-RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
-RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
-
-
-expected_results = [
- ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator
-]
-
-
-class DummyTestModel(models.Model):
- name = models.CharField(max_length=42, default='')
-
-
-class BasicRendererTests(TestCase):
- def test_expected_results(self):
- for value, renderer_cls, expected in expected_results:
- output = renderer_cls().render(value)
- self.assertEqual(output, expected)
-
-
-class RendererA(BaseRenderer):
- media_type = 'mock/renderera'
- format = "formata"
-
- def render(self, data, media_type=None, renderer_context=None):
- return RENDERER_A_SERIALIZER(data)
-
-
-class RendererB(BaseRenderer):
- media_type = 'mock/rendererb'
- format = "formatb"
-
- def render(self, data, media_type=None, renderer_context=None):
- return RENDERER_B_SERIALIZER(data)
-
-
-class MockView(APIView):
- renderer_classes = (RendererA, RendererB)
-
- def get(self, request, **kwargs):
- response = Response(DUMMYCONTENT, status=DUMMYSTATUS)
- return response
-
-
-class MockGETView(APIView):
- def get(self, request, **kwargs):
- return Response({'foo': ['bar', 'baz']})
-
-
-
-class MockPOSTView(APIView):
- def post(self, request, **kwargs):
- return Response({'foo': request.DATA})
-
-
-class EmptyGETView(APIView):
- renderer_classes = (JSONRenderer,)
-
- def get(self, request, **kwargs):
- return Response(status=status.HTTP_204_NO_CONTENT)
-
-
-class HTMLView(APIView):
- renderer_classes = (BrowsableAPIRenderer, )
-
- def get(self, request, **kwargs):
- return Response('text')
-
-
-class HTMLView1(APIView):
- renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
-
- def get(self, request, **kwargs):
- return Response('text')
-
-urlpatterns = patterns('',
- url(r'^.*\.(?P.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
- url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
- url(r'^cache$', MockGETView.as_view()),
- url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
- url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
- url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
- url(r'^html$', HTMLView.as_view()),
- url(r'^html1$', HTMLView1.as_view()),
- url(r'^empty$', EmptyGETView.as_view()),
- url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
-)
-
-
-class POSTDeniedPermission(permissions.BasePermission):
- def has_permission(self, request, view):
- return request.method != 'POST'
-
-
-class POSTDeniedView(APIView):
- renderer_classes = (BrowsableAPIRenderer,)
- permission_classes = (POSTDeniedPermission,)
-
- def get(self, request):
- return Response()
-
- def post(self, request):
- return Response()
-
- def put(self, request):
- return Response()
-
- def patch(self, request):
- return Response()
-
-
-class DocumentingRendererTests(TestCase):
- def test_only_permitted_forms_are_displayed(self):
- view = POSTDeniedView.as_view()
- request = APIRequestFactory().get('/')
- response = view(request).render()
- self.assertNotContains(response, '>POST<')
- self.assertContains(response, '>PUT<')
- self.assertContains(response, '>PATCH<')
-
-
-class RendererEndToEndTests(TestCase):
- """
- End-to-end testing of renderers using an RendererMixin on a generic view.
- """
-
- urls = 'rest_framework.tests.test_renderers'
-
- def test_default_renderer_serializes_content(self):
- """If the Accept header is not set the default renderer should serialize the response."""
- resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_head_method_serializes_no_content(self):
- """No response must be included in HEAD requests."""
- resp = self.client.head('/')
- self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, six.b(''))
-
- def test_default_renderer_serializes_content_on_accept_any(self):
- """If the Accept header is set to */* the default renderer should serialize the response."""
- resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
- (In this case we check that works for the default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_non_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
- (In this case we check that works for a non-default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_accept_query(self):
- """The '_accept' query string should behave in the same way as the Accept header."""
- param = '?%s=%s' % (
- api_settings.URL_ACCEPT_OVERRIDE,
- RendererB.media_type
- )
- resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
- """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
-
- def test_specified_renderer_serializes_content_on_format_query(self):
- """If a 'format' query is specified, the renderer with the matching
- format attribute should serialize the response."""
- param = '?%s=%s' % (
- api_settings.URL_FORMAT_OVERRIDE,
- RendererB.format
- )
- resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_format_kwargs(self):
- """If a 'format' keyword arg is specified, the renderer with the matching
- format attribute should serialize the response."""
- resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
- """If both a 'format' query and a matching Accept header specified,
- the renderer with the matching format attribute should serialize the response."""
- param = '?%s=%s' % (
- api_settings.URL_FORMAT_OVERRIDE,
- RendererB.format
- )
- resp = self.client.get('/' + param,
- HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_parse_error_renderers_browsable_api(self):
- """Invalid data should still render the browsable API correctly."""
- resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
-
- def test_204_no_content_responses_have_no_content_type_set(self):
- """
- Regression test for #1196
-
- https://github.com/tomchristie/django-rest-framework/issues/1196
- """
- resp = self.client.get('/empty')
- self.assertEqual(resp.get('Content-Type', None), None)
- self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
-
- def test_contains_headers_of_api_response(self):
- """
- Issue #1437
-
- Test we display the headers of the API response and not those from the
- HTML response
- """
- resp = self.client.get('/html1')
- self.assertContains(resp, '>GET, HEAD, OPTIONS<')
- self.assertContains(resp, '>application/json<')
- self.assertNotContains(resp, '>text/html; charset=utf-8<')
-
-
-_flat_repr = '{"foo": ["bar", "baz"]}'
-_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
-
-
-def strip_trailing_whitespace(content):
- """
- Seems to be some inconsistencies re. trailing whitespace with
- different versions of the json lib.
- """
- return re.sub(' +\n', '\n', content)
-
-
-class JSONRendererTests(TestCase):
- """
- Tests specific to the JSON Renderer
- """
-
- def test_render_lazy_strings(self):
- """
- JSONRenderer should deal with lazy translated strings.
- """
- ret = JSONRenderer().render(_('test'))
- self.assertEqual(ret, b'"test"')
-
- def test_render_queryset_values(self):
- o = DummyTestModel.objects.create(name='dummy')
- qs = DummyTestModel.objects.values('id', 'name')
- ret = JSONRenderer().render(qs)
- data = json.loads(ret.decode('utf-8'))
- self.assertEquals(data, [{'id': o.id, 'name': o.name}])
-
- def test_render_queryset_values_list(self):
- o = DummyTestModel.objects.create(name='dummy')
- qs = DummyTestModel.objects.values_list('id', 'name')
- ret = JSONRenderer().render(qs)
- data = json.loads(ret.decode('utf-8'))
- self.assertEquals(data, [[o.id, o.name]])
-
- def test_render_dict_abc_obj(self):
- class Dict(MutableMapping):
- def __init__(self):
- self._dict = dict()
- def __getitem__(self, key):
- return self._dict.__getitem__(key)
- def __setitem__(self, key, value):
- return self._dict.__setitem__(key, value)
- def __delitem__(self, key):
- return self._dict.__delitem__(key)
- def __iter__(self):
- return self._dict.__iter__()
- def __len__(self):
- return self._dict.__len__()
- def keys(self):
- return self._dict.keys()
-
- x = Dict()
- x['key'] = 'string value'
- x[2] = 3
- ret = JSONRenderer().render(x)
- data = json.loads(ret.decode('utf-8'))
- self.assertEquals(data, {'key': 'string value', '2': 3})
-
- def test_render_obj_with_getitem(self):
- class DictLike(object):
- def __init__(self):
- self._dict = {}
- def set(self, value):
- self._dict = dict(value)
- def __getitem__(self, key):
- return self._dict[key]
-
- x = DictLike()
- x.set({'a': 1, 'b': 'string'})
- with self.assertRaises(TypeError):
- JSONRenderer().render(x)
-
- def test_without_content_type_args(self):
- """
- Test basic JSON rendering.
- """
- obj = {'foo': ['bar', 'baz']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- # Fix failing test case which depends on version of JSON library.
- self.assertEqual(content.decode('utf-8'), _flat_repr)
-
- def test_with_content_type_args(self):
- """
- Test JSON rendering with additional content type arguments supplied.
- """
- obj = {'foo': ['bar', 'baz']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json; indent=2')
- self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
-
- def test_check_ascii(self):
- obj = {'countries': ['United Kingdom', 'France', 'España']}
- renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8'))
-
-
-class UnicodeJSONRendererTests(TestCase):
- """
- Tests specific for the Unicode JSON Renderer
- """
- def test_proper_encoding(self):
- obj = {'countries': ['United Kingdom', 'France', 'España']}
- renderer = UnicodeJSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8'))
-
-
-class JSONPRendererTests(TestCase):
- """
- Tests specific to the JSONP Renderer
- """
-
- urls = 'rest_framework.tests.test_renderers'
-
- def test_without_callback_with_json_renderer(self):
- """
- Test JSONP rendering with View JSON Renderer.
- """
- resp = self.client.get('/jsonp/jsonrenderer',
- HTTP_ACCEPT='application/javascript')
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(resp.content,
- ('callback(%s);' % _flat_repr).encode('ascii'))
-
- def test_without_callback_without_json_renderer(self):
- """
- Test JSONP rendering without View JSON Renderer.
- """
- resp = self.client.get('/jsonp/nojsonrenderer',
- HTTP_ACCEPT='application/javascript')
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(resp.content,
- ('callback(%s);' % _flat_repr).encode('ascii'))
-
- def test_with_callback(self):
- """
- Test JSONP rendering with callback function name.
- """
- callback_func = 'myjsonpcallback'
- resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
- HTTP_ACCEPT='application/javascript')
- self.assertEqual(resp.status_code, status.HTTP_200_OK)
- self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
- self.assertEqual(resp.content,
- ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
-
-
-if yaml:
- _yaml_repr = 'foo: [bar, baz]\n'
-
- class YAMLRendererTests(TestCase):
- """
- Tests specific to the YAML Renderer
- """
-
- def test_render(self):
- """
- Test basic YAML rendering.
- """
- obj = {'foo': ['bar', 'baz']}
- renderer = YAMLRenderer()
- content = renderer.render(obj, 'application/yaml')
- self.assertEqual(content, _yaml_repr)
-
- def test_render_and_parse(self):
- """
- Test rendering and then parsing returns the original object.
- IE obj -> render -> parse -> obj.
- """
- obj = {'foo': ['bar', 'baz']}
-
- renderer = YAMLRenderer()
- parser = YAMLParser()
-
- content = renderer.render(obj, 'application/yaml')
- data = parser.parse(StringIO(content))
- self.assertEqual(obj, data)
-
- def test_render_decimal(self):
- """
- Test YAML decimal rendering.
- """
- renderer = YAMLRenderer()
- content = renderer.render({'field': Decimal('111.2')}, 'application/yaml')
- self.assertYAMLContains(content, "field: '111.2'")
-
- def assertYAMLContains(self, content, string):
- self.assertTrue(string in content, '%r not in %r' % (string, content))
-
-
-class XMLRendererTestCase(TestCase):
- """
- Tests specific to the XML Renderer
- """
-
- _complex_data = {
- "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
- "name": "name",
- "sub_data_list": [
- {
- "sub_id": 1,
- "sub_name": "first"
- },
- {
- "sub_id": 2,
- "sub_name": "second"
- }
- ]
- }
-
- def test_render_string(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 'astring'}, 'application/xml')
- self.assertXMLContains(content, 'astring')
-
- def test_render_integer(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 111}, 'application/xml')
- self.assertXMLContains(content, '111')
-
- def test_render_datetime(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({
- 'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
- }, 'application/xml')
- self.assertXMLContains(content, '2011-12-25 12:45:00')
-
- def test_render_float(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': 123.4}, 'application/xml')
- self.assertXMLContains(content, '123.4')
-
- def test_render_decimal(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
- self.assertXMLContains(content, '111.2')
-
- def test_render_none(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render({'field': None}, 'application/xml')
- self.assertXMLContains(content, '')
-
- def test_render_complex_data(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = renderer.render(self._complex_data, 'application/xml')
- self.assertXMLContains(content, 'first')
- self.assertXMLContains(content, 'second')
-
- @unittest.skipUnless(etree, 'defusedxml not installed')
- def test_render_and_parse_complex_data(self):
- """
- Test XML rendering.
- """
- renderer = XMLRenderer()
- content = StringIO(renderer.render(self._complex_data, 'application/xml'))
-
- parser = XMLParser()
- complex_data_out = parser.parse(content)
- error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
- self.assertEqual(self._complex_data, complex_data_out, error_msg)
-
- def assertXMLContains(self, xml, string):
- self.assertTrue(xml.startswith('\n'))
- self.assertTrue(xml.endswith(''))
- self.assertTrue(string in xml, '%r not in %r' % (string, xml))
-
-
-# Tests for caching issue, #346
-class CacheRenderTest(TestCase):
- """
- Tests specific to caching responses
- """
-
- urls = 'rest_framework.tests.test_renderers'
-
- cache_key = 'just_a_cache_key'
-
- @classmethod
- def _get_pickling_errors(cls, obj, seen=None):
- """ Return any errors that would be raised if `obj' is pickled
- Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
- """
- if seen == None:
- seen = []
- try:
- state = obj.__getstate__()
- except AttributeError:
- return
- if state == None:
- return
- if isinstance(state, tuple):
- if not isinstance(state[0], dict):
- state = state[1]
- else:
- state = state[0].update(state[1])
- result = {}
- for i in state:
- try:
- pickle.dumps(state[i], protocol=2)
- except pickle.PicklingError:
- if not state[i] in seen:
- seen.append(state[i])
- result[i] = cls._get_pickling_errors(state[i], seen)
- return result
-
- def http_resp(self, http_method, url):
- """
- Simple wrapper for Client http requests
- Removes the `client' and `request' attributes from as they are
- added by django.test.client.Client and not part of caching
- responses outside of tests.
- """
- method = getattr(self.client, http_method)
- resp = method(url)
- del resp.client, resp.request
- return resp
-
- def test_obj_pickling(self):
- """
- Test that responses are properly pickled
- """
- resp = self.http_resp('get', '/cache')
-
- # Make sure that no pickling errors occurred
- self.assertEqual(self._get_pickling_errors(resp), {})
-
- # Unfortunately LocMem backend doesn't raise PickleErrors but returns
- # None instead.
- cache.set(self.cache_key, resp)
- self.assertTrue(cache.get(self.cache_key) is not None)
-
- def test_head_caching(self):
- """
- Test caching of HEAD requests
- """
- resp = self.http_resp('head', '/cache')
- cache.set(self.cache_key, resp)
-
- cached_resp = cache.get(self.cache_key)
- self.assertIsInstance(cached_resp, Response)
-
- def test_get_caching(self):
- """
- Test caching of GET requests
- """
- resp = self.http_resp('get', '/cache')
- cache.set(self.cache_key, resp)
-
- cached_resp = cache.get(self.cache_key)
- self.assertIsInstance(cached_resp, Response)
- self.assertEqual(cached_resp.content, resp.content)
diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py
deleted file mode 100644
index c0b50f33..00000000
--- a/rest_framework/tests/test_request.py
+++ /dev/null
@@ -1,347 +0,0 @@
-"""
-Tests for content parsing, and form-overloaded content parsing.
-"""
-from __future__ import unicode_literals
-from django.contrib.auth.models import User
-from django.contrib.auth import authenticate, login, logout
-from django.contrib.sessions.middleware import SessionMiddleware
-from django.core.handlers.wsgi import WSGIRequest
-from django.test import TestCase
-from rest_framework import status
-from rest_framework.authentication import SessionAuthentication
-from rest_framework.compat import patterns
-from rest_framework.parsers import (
- BaseParser,
- FormParser,
- MultiPartParser,
- JSONParser
-)
-from rest_framework.request import Request, Empty
-from rest_framework.response import Response
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory, APIClient
-from rest_framework.views import APIView
-from rest_framework.compat import six
-from io import BytesIO
-import json
-
-
-factory = APIRequestFactory()
-
-
-class PlainTextParser(BaseParser):
- media_type = 'text/plain'
-
- def parse(self, stream, media_type=None, parser_context=None):
- """
- Returns a 2-tuple of `(data, files)`.
-
- `data` will simply be a string representing the body of the request.
- `files` will always be `None`.
- """
- return stream.read()
-
-
-class TestMethodOverloading(TestCase):
- def test_method(self):
- """
- Request methods should be same as underlying request.
- """
- request = Request(factory.get('/'))
- self.assertEqual(request.method, 'GET')
- request = Request(factory.post('/'))
- self.assertEqual(request.method, 'POST')
-
- def test_overloaded_method(self):
- """
- POST requests can be overloaded to another method by setting a
- reserved form field
- """
- request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
- self.assertEqual(request.method, 'DELETE')
-
- def test_x_http_method_override_header(self):
- """
- POST requests can also be overloaded to another method by setting
- the X-HTTP-Method-Override header.
- """
- request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
- self.assertEqual(request.method, 'DELETE')
-
- request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
- self.assertEqual(request.method, 'DELETE')
-
-
-class TestContentParsing(TestCase):
- def test_standard_behaviour_determines_no_content_GET(self):
- """
- Ensure request.DATA returns empty QueryDict for GET request.
- """
- request = Request(factory.get('/'))
- self.assertEqual(request.DATA, {})
-
- def test_standard_behaviour_determines_no_content_HEAD(self):
- """
- Ensure request.DATA returns empty QueryDict for HEAD request.
- """
- request = Request(factory.head('/'))
- self.assertEqual(request.DATA, {})
-
- def test_request_DATA_with_form_content(self):
- """
- Ensure request.DATA returns content for POST request with form content.
- """
- data = {'qwerty': 'uiop'}
- request = Request(factory.post('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(list(request.DATA.items()), list(data.items()))
-
- def test_request_DATA_with_text_content(self):
- """
- Ensure request.DATA returns content for POST request with
- non-form content.
- """
- content = six.b('qwerty')
- content_type = 'text/plain'
- request = Request(factory.post('/', content, content_type=content_type))
- request.parsers = (PlainTextParser(),)
- self.assertEqual(request.DATA, content)
-
- def test_request_POST_with_form_content(self):
- """
- Ensure request.POST returns content for POST request with form content.
- """
- data = {'qwerty': 'uiop'}
- request = Request(factory.post('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(list(request.POST.items()), list(data.items()))
-
- def test_standard_behaviour_determines_form_content_PUT(self):
- """
- Ensure request.DATA returns content for PUT request with form content.
- """
- data = {'qwerty': 'uiop'}
- request = Request(factory.put('/', data))
- request.parsers = (FormParser(), MultiPartParser())
- self.assertEqual(list(request.DATA.items()), list(data.items()))
-
- def test_standard_behaviour_determines_non_form_content_PUT(self):
- """
- Ensure request.DATA returns content for PUT request with
- non-form content.
- """
- content = six.b('qwerty')
- content_type = 'text/plain'
- request = Request(factory.put('/', content, content_type=content_type))
- request.parsers = (PlainTextParser(), )
- self.assertEqual(request.DATA, content)
-
- def test_overloaded_behaviour_allows_content_tunnelling(self):
- """
- Ensure request.DATA returns content for overloaded POST request.
- """
- json_data = {'foobar': 'qwerty'}
- content = json.dumps(json_data)
- content_type = 'application/json'
- form_data = {
- api_settings.FORM_CONTENT_OVERRIDE: content,
- api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
- }
- request = Request(factory.post('/', form_data))
- request.parsers = (JSONParser(), )
- self.assertEqual(request.DATA, json_data)
-
- def test_form_POST_unicode(self):
- """
- JSON POST via default web interface with unicode data
- """
- # Note: environ and other variables here have simplified content compared to real Request
- CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D'
- environ = {
- 'REQUEST_METHOD': 'POST',
- 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
- 'CONTENT_LENGTH': len(CONTENT),
- 'wsgi.input': BytesIO(CONTENT),
- }
- wsgi_request = WSGIRequest(environ=environ)
- wsgi_request._load_post_and_files()
- parsers = (JSONParser(), FormParser(), MultiPartParser())
- parser_context = {
- 'encoding': 'utf-8',
- 'kwargs': {},
- 'args': (),
- }
- request = Request(wsgi_request, parsers=parsers, parser_context=parser_context)
- method = request.method
- self.assertEqual(method, 'POST')
- self.assertEqual(request._content_type, 'application/json')
- self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}')
- self.assertEqual(request._data, Empty)
- self.assertEqual(request._files, Empty)
-
- # def test_accessing_post_after_data_form(self):
- # """
- # Ensures request.POST can be accessed after request.DATA in
- # form request.
- # """
- # data = {'qwerty': 'uiop'}
- # request = factory.post('/', data=data)
- # self.assertEqual(request.DATA.items(), data.items())
- # self.assertEqual(request.POST.items(), data.items())
-
- # def test_accessing_post_after_data_for_json(self):
- # """
- # Ensures request.POST can be accessed after request.DATA in
- # json request.
- # """
- # data = {'qwerty': 'uiop'}
- # content = json.dumps(data)
- # content_type = 'application/json'
- # parsers = (JSONParser, )
-
- # request = factory.post('/', content, content_type=content_type,
- # parsers=parsers)
- # self.assertEqual(request.DATA.items(), data.items())
- # self.assertEqual(request.POST.items(), [])
-
- # def test_accessing_post_after_data_for_overloaded_json(self):
- # """
- # Ensures request.POST can be accessed after request.DATA in overloaded
- # json request.
- # """
- # data = {'qwerty': 'uiop'}
- # content = json.dumps(data)
- # content_type = 'application/json'
- # parsers = (JSONParser, )
- # form_data = {Request._CONTENT_PARAM: content,
- # Request._CONTENTTYPE_PARAM: content_type}
-
- # request = factory.post('/', form_data, parsers=parsers)
- # self.assertEqual(request.DATA.items(), data.items())
- # self.assertEqual(request.POST.items(), form_data.items())
-
- # def test_accessing_data_after_post_form(self):
- # """
- # Ensures request.DATA can be accessed after request.POST in
- # form request.
- # """
- # data = {'qwerty': 'uiop'}
- # parsers = (FormParser, MultiPartParser)
- # request = factory.post('/', data, parsers=parsers)
-
- # self.assertEqual(request.POST.items(), data.items())
- # self.assertEqual(request.DATA.items(), data.items())
-
- # def test_accessing_data_after_post_for_json(self):
- # """
- # Ensures request.DATA can be accessed after request.POST in
- # json request.
- # """
- # data = {'qwerty': 'uiop'}
- # content = json.dumps(data)
- # content_type = 'application/json'
- # parsers = (JSONParser, )
- # request = factory.post('/', content, content_type=content_type,
- # parsers=parsers)
- # self.assertEqual(request.POST.items(), [])
- # self.assertEqual(request.DATA.items(), data.items())
-
- # def test_accessing_data_after_post_for_overloaded_json(self):
- # """
- # Ensures request.DATA can be accessed after request.POST in overloaded
- # json request
- # """
- # data = {'qwerty': 'uiop'}
- # content = json.dumps(data)
- # content_type = 'application/json'
- # parsers = (JSONParser, )
- # form_data = {Request._CONTENT_PARAM: content,
- # Request._CONTENTTYPE_PARAM: content_type}
-
- # request = factory.post('/', form_data, parsers=parsers)
- # self.assertEqual(request.POST.items(), form_data.items())
- # self.assertEqual(request.DATA.items(), data.items())
-
-
-class MockView(APIView):
- authentication_classes = (SessionAuthentication,)
-
- def post(self, request):
- if request.POST.get('example') is not None:
- return Response(status=status.HTTP_200_OK)
-
- return Response(status=status.INTERNAL_SERVER_ERROR)
-
-urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
-)
-
-
-class TestContentParsingWithAuthentication(TestCase):
- urls = 'rest_framework.tests.test_request'
-
- def setUp(self):
- self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(self.username, self.email, self.password)
-
- def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
- """
- Ensures request.POST exists after SessionAuthentication when user
- doesn't log in.
- """
- content = {'example': 'example'}
-
- response = self.client.post('/', content)
- self.assertEqual(status.HTTP_200_OK, response.status_code)
-
- response = self.csrf_client.post('/', content)
- self.assertEqual(status.HTTP_200_OK, response.status_code)
-
- # def test_user_logged_in_authentication_has_post_when_logged_in(self):
- # """Ensures request.POST exists after UserLoggedInAuthentication when user does log in"""
- # self.client.login(username='john', password='password')
- # self.csrf_client.login(username='john', password='password')
- # content = {'example': 'example'}
-
- # response = self.client.post('/', content)
- # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
-
- # response = self.csrf_client.post('/', content)
- # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
-
-
-class TestUserSetter(TestCase):
-
- def setUp(self):
- # Pass request object through session middleware so session is
- # available to login and logout functions
- self.request = Request(factory.get('/'))
- SessionMiddleware().process_request(self.request)
-
- User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
- self.user = authenticate(username='ringo', password='yellow')
-
- def test_user_can_be_set(self):
- self.request.user = self.user
- self.assertEqual(self.request.user, self.user)
-
- def test_user_can_login(self):
- login(self.request, self.user)
- self.assertEqual(self.request.user, self.user)
-
- def test_user_can_logout(self):
- self.request.user = self.user
- self.assertFalse(self.request.user.is_anonymous())
- logout(self.request)
- self.assertTrue(self.request.user.is_anonymous())
-
-
-class TestAuthSetter(TestCase):
-
- def test_auth_can_be_set(self):
- request = Request(factory.get('/'))
- request.auth = 'DUMMY'
- self.assertEqual(request.auth, 'DUMMY')
diff --git a/rest_framework/tests/test_response.py b/rest_framework/tests/test_response.py
deleted file mode 100644
index eea3c641..00000000
--- a/rest_framework/tests/test_response.py
+++ /dev/null
@@ -1,278 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.tests.models import BasicModel, BasicModelSerializer
-from rest_framework.compat import patterns, url, include
-from rest_framework.response import Response
-from rest_framework.views import APIView
-from rest_framework import generics
-from rest_framework import routers
-from rest_framework import status
-from rest_framework.renderers import (
- BaseRenderer,
- JSONRenderer,
- BrowsableAPIRenderer
-)
-from rest_framework import viewsets
-from rest_framework.settings import api_settings
-from rest_framework.compat import six
-
-
-class MockPickleRenderer(BaseRenderer):
- media_type = 'application/pickle'
-
-
-class MockJsonRenderer(BaseRenderer):
- media_type = 'application/json'
-
-
-class MockTextMediaRenderer(BaseRenderer):
- media_type = 'text/html'
-
-DUMMYSTATUS = status.HTTP_200_OK
-DUMMYCONTENT = 'dummycontent'
-
-RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
-RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
-
-
-class RendererA(BaseRenderer):
- media_type = 'mock/renderera'
- format = "formata"
-
- def render(self, data, media_type=None, renderer_context=None):
- return RENDERER_A_SERIALIZER(data)
-
-
-class RendererB(BaseRenderer):
- media_type = 'mock/rendererb'
- format = "formatb"
-
- def render(self, data, media_type=None, renderer_context=None):
- return RENDERER_B_SERIALIZER(data)
-
-
-class RendererC(RendererB):
- media_type = 'mock/rendererc'
- format = 'formatc'
- charset = "rendererc"
-
-
-class MockView(APIView):
- renderer_classes = (RendererA, RendererB, RendererC)
-
- def get(self, request, **kwargs):
- return Response(DUMMYCONTENT, status=DUMMYSTATUS)
-
-
-class MockViewSettingContentType(APIView):
- renderer_classes = (RendererA, RendererB, RendererC)
-
- def get(self, request, **kwargs):
- return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
-
-
-class HTMLView(APIView):
- renderer_classes = (BrowsableAPIRenderer, )
-
- def get(self, request, **kwargs):
- return Response('text')
-
-
-class HTMLView1(APIView):
- renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
-
- def get(self, request, **kwargs):
- return Response('text')
-
-
-class HTMLNewModelViewSet(viewsets.ModelViewSet):
- model = BasicModel
-
-
-class HTMLNewModelView(generics.ListCreateAPIView):
- renderer_classes = (BrowsableAPIRenderer,)
- permission_classes = []
- serializer_class = BasicModelSerializer
- model = BasicModel
-
-
-new_model_viewset_router = routers.DefaultRouter()
-new_model_viewset_router.register(r'', HTMLNewModelViewSet)
-
-
-urlpatterns = patterns('',
- url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^.*\.(?P.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^html$', HTMLView.as_view()),
- url(r'^html1$', HTMLView1.as_view()),
- url(r'^html_new_model$', HTMLNewModelView.as_view()),
- url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
- url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
-)
-
-
-# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
-class RendererIntegrationTests(TestCase):
- """
- End-to-end testing of renderers using an ResponseMixin on a generic view.
- """
-
- urls = 'rest_framework.tests.test_response'
-
- def test_default_renderer_serializes_content(self):
- """If the Accept header is not set the default renderer should serialize the response."""
- resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_head_method_serializes_no_content(self):
- """No response must be included in HEAD requests."""
- resp = self.client.head('/')
- self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, six.b(''))
-
- def test_default_renderer_serializes_content_on_accept_any(self):
- """If the Accept header is set to */* the default renderer should serialize the response."""
- resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
- (In this case we check that works for the default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_non_default_case(self):
- """If the Accept header is set the specified renderer should serialize the response.
- (In this case we check that works for a non-default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_accept_query(self):
- """The '_accept' query string should behave in the same way as the Accept header."""
- param = '?%s=%s' % (
- api_settings.URL_ACCEPT_OVERRIDE,
- RendererB.media_type
- )
- resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_format_query(self):
- """If a 'format' query is specified, the renderer with the matching
- format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_serializes_content_on_format_kwargs(self):
- """If a 'format' keyword arg is specified, the renderer with the matching
- format attribute should serialize the response."""
- resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
- def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
- """If both a 'format' query and a matching Accept header specified,
- the renderer with the matching format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format,
- HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
- self.assertEqual(resp.status_code, DUMMYSTATUS)
-
-
-class Issue122Tests(TestCase):
- """
- Tests that covers #122.
- """
- urls = 'rest_framework.tests.test_response'
-
- def test_only_html_renderer(self):
- """
- Test if no infinite recursion occurs.
- """
- self.client.get('/html')
-
- def test_html_renderer_is_first(self):
- """
- Test if no infinite recursion occurs.
- """
- self.client.get('/html1')
-
-
-class Issue467Tests(TestCase):
- """
- Tests for #467
- """
-
- urls = 'rest_framework.tests.test_response'
-
- def test_form_has_label_and_help_text(self):
- resp = self.client.get('/html_new_model')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertContains(resp, 'Text comes here')
- self.assertContains(resp, 'Text description.')
-
-
-class Issue807Tests(TestCase):
- """
- Covers #807
- """
-
- urls = 'rest_framework.tests.test_response'
-
- def test_does_not_append_charset_by_default(self):
- """
- Renderers don't include a charset unless set explicitly.
- """
- headers = {"HTTP_ACCEPT": RendererA.media_type}
- resp = self.client.get('/', **headers)
- expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8')
- self.assertEqual(expected, resp['Content-Type'])
-
- def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
- """
- If renderer class has charset attribute declared, it gets appended
- to Response's Content-Type
- """
- headers = {"HTTP_ACCEPT": RendererC.media_type}
- resp = self.client.get('/', **headers)
- expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
- self.assertEqual(expected, resp['Content-Type'])
-
- def test_content_type_set_explictly_on_response(self):
- """
- The content type may be set explictly on the response.
- """
- headers = {"HTTP_ACCEPT": RendererC.media_type}
- resp = self.client.get('/setbyview', **headers)
- self.assertEqual('setbyview', resp['Content-Type'])
-
- def test_viewset_label_help_text(self):
- param = '?%s=%s' % (
- api_settings.URL_ACCEPT_OVERRIDE,
- 'text/html'
- )
- resp = self.client.get('/html_new_model_viewset/' + param)
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertContains(resp, 'Text comes here')
- self.assertContains(resp, 'Text description.')
-
- def test_form_has_label_and_help_text(self):
- resp = self.client.get('/html_new_model')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
- self.assertContains(resp, 'Text comes here')
- self.assertContains(resp, 'Text description.')
diff --git a/rest_framework/tests/test_reverse.py b/rest_framework/tests/test_reverse.py
deleted file mode 100644
index 690a30b1..00000000
--- a/rest_framework/tests/test_reverse.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.compat import patterns, url
-from rest_framework.reverse import reverse
-from rest_framework.test import APIRequestFactory
-
-factory = APIRequestFactory()
-
-
-def null_view(request):
- pass
-
-urlpatterns = patterns('',
- url(r'^view$', null_view, name='view'),
-)
-
-
-class ReverseTests(TestCase):
- """
- Tests for fully qualified URLs when using `reverse`.
- """
- urls = 'rest_framework.tests.test_reverse'
-
- def test_reversed_urls_are_fully_qualified(self):
- request = factory.get('/view')
- url = reverse('view', request=request)
- self.assertEqual(url, 'http://testserver/view')
diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py
deleted file mode 100644
index e723f7d4..00000000
--- a/rest_framework/tests/test_routers.py
+++ /dev/null
@@ -1,216 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from django.core.exceptions import ImproperlyConfigured
-from rest_framework import serializers, viewsets, permissions
-from rest_framework.compat import include, patterns, url
-from rest_framework.decorators import link, action
-from rest_framework.response import Response
-from rest_framework.routers import SimpleRouter, DefaultRouter
-from rest_framework.test import APIRequestFactory
-
-factory = APIRequestFactory()
-
-urlpatterns = patterns('',)
-
-
-class BasicViewSet(viewsets.ViewSet):
- def list(self, request, *args, **kwargs):
- return Response({'method': 'list'})
-
- @action()
- def action1(self, request, *args, **kwargs):
- return Response({'method': 'action1'})
-
- @action()
- def action2(self, request, *args, **kwargs):
- return Response({'method': 'action2'})
-
- @action(methods=['post', 'delete'])
- def action3(self, request, *args, **kwargs):
- return Response({'method': 'action2'})
-
- @link()
- def link1(self, request, *args, **kwargs):
- return Response({'method': 'link1'})
-
- @link()
- def link2(self, request, *args, **kwargs):
- return Response({'method': 'link2'})
-
-
-class TestSimpleRouter(TestCase):
- def setUp(self):
- self.router = SimpleRouter()
-
- def test_link_and_action_decorator(self):
- routes = self.router.get_routes(BasicViewSet)
- decorator_routes = routes[2:]
- # Make sure all these endpoints exist and none have been clobbered
- for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']):
- route = decorator_routes[i]
- # check url listing
- self.assertEqual(route.url,
- '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
- # check method to function mapping
- if endpoint == 'action3':
- methods_map = ['post', 'delete']
- elif endpoint.startswith('action'):
- methods_map = ['post']
- else:
- methods_map = ['get']
- for method in methods_map:
- self.assertEqual(route.mapping[method], endpoint)
-
-
-class RouterTestModel(models.Model):
- uuid = models.CharField(max_length=20)
- text = models.CharField(max_length=200)
-
-
-class TestCustomLookupFields(TestCase):
- """
- Ensure that custom lookup fields are correctly routed.
- """
- urls = 'rest_framework.tests.test_routers'
-
- def setUp(self):
- class NoteSerializer(serializers.HyperlinkedModelSerializer):
- class Meta:
- model = RouterTestModel
- lookup_field = 'uuid'
- fields = ('url', 'uuid', 'text')
-
- class NoteViewSet(viewsets.ModelViewSet):
- queryset = RouterTestModel.objects.all()
- serializer_class = NoteSerializer
- lookup_field = 'uuid'
-
- RouterTestModel.objects.create(uuid='123', text='foo bar')
-
- self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
-
- from rest_framework.tests import test_routers
- urls = getattr(test_routers, 'urlpatterns')
- urls += patterns('',
- url(r'^', include(self.router.urls)),
- )
-
- def test_custom_lookup_field_route(self):
- detail_route = self.router.urls[-1]
- detail_url_pattern = detail_route.regex.pattern
- self.assertIn('', detail_url_pattern)
-
- def test_retrieve_lookup_field_list_view(self):
- response = self.client.get('/notes/')
- self.assertEqual(response.data,
- [{
- "url": "http://testserver/notes/123/",
- "uuid": "123", "text": "foo bar"
- }]
- )
-
- def test_retrieve_lookup_field_detail_view(self):
- response = self.client.get('/notes/123/')
- self.assertEqual(response.data,
- {
- "url": "http://testserver/notes/123/",
- "uuid": "123", "text": "foo bar"
- }
- )
-
-
-class TestTrailingSlashIncluded(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
-
- self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
-
- def test_urls_have_trailing_slash_by_default(self):
- expected = ['^notes/$', '^notes/(?P[^/]+)/$']
- for idx in range(len(expected)):
- self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
-
-
-class TestTrailingSlashRemoved(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
-
- self.router = SimpleRouter(trailing_slash=False)
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
-
- def test_urls_can_have_trailing_slash_removed(self):
- expected = ['^notes$', '^notes/(?P[^/.]+)$']
- for idx in range(len(expected)):
- self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
-
-
-class TestNameableRoot(TestCase):
- def setUp(self):
- class NoteViewSet(viewsets.ModelViewSet):
- model = RouterTestModel
- self.router = DefaultRouter()
- self.router.root_view_name = 'nameable-root'
- self.router.register(r'notes', NoteViewSet)
- self.urls = self.router.urls
-
- def test_router_has_custom_name(self):
- expected = 'nameable-root'
- self.assertEqual(expected, self.urls[0].name)
-
-
-class TestActionKeywordArgs(TestCase):
- """
- Ensure keyword arguments passed in the `@action` decorator
- are properly handled. Refs #940.
- """
-
- def setUp(self):
- class TestViewSet(viewsets.ModelViewSet):
- permission_classes = []
-
- @action(permission_classes=[permissions.AllowAny])
- def custom(self, request, *args, **kwargs):
- return Response({
- 'permission_classes': self.permission_classes
- })
-
- self.router = SimpleRouter()
- self.router.register(r'test', TestViewSet, base_name='test')
- self.view = self.router.urls[-1].callback
-
- def test_action_kwargs(self):
- request = factory.post('/test/0/custom/')
- response = self.view(request)
- self.assertEqual(
- response.data,
- {'permission_classes': [permissions.AllowAny]}
- )
-
-
-class TestActionAppliedToExistingRoute(TestCase):
- """
- Ensure `@action` decorator raises an except when applied
- to an existing route
- """
-
- def test_exception_raised_when_action_applied_to_existing_route(self):
- class TestViewSet(viewsets.ModelViewSet):
-
- @action()
- def retrieve(self, request, *args, **kwargs):
- return Response({
- 'hello': 'world'
- })
-
- self.router = SimpleRouter()
- self.router.register(r'test', TestViewSet, base_name='test')
-
- with self.assertRaises(ImproperlyConfigured):
- self.router.urls
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
deleted file mode 100644
index 6b1e333e..00000000
--- a/rest_framework/tests/test_serializer.py
+++ /dev/null
@@ -1,1857 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import unicode_literals
-from django.db import models
-from django.db.models.fields import BLANK_CHOICE_DASH
-from django.test import TestCase
-from django.utils.datastructures import MultiValueDict
-from django.utils.translation import ugettext_lazy as _
-from rest_framework import serializers, fields, relations
-from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
- BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
- ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel)
-from rest_framework.tests.models import BasicModelSerializer
-import datetime
-import pickle
-
-
-class SubComment(object):
- def __init__(self, sub_comment):
- self.sub_comment = sub_comment
-
-
-class Comment(object):
- def __init__(self, email, content, created):
- self.email = email
- self.content = content
- self.created = created or datetime.datetime.now()
-
- def __eq__(self, other):
- return all([getattr(self, attr) == getattr(other, attr)
- for attr in ('email', 'content', 'created')])
-
- def get_sub_comment(self):
- sub_comment = SubComment('And Merry Christmas!')
- return sub_comment
-
-
-class CommentSerializer(serializers.Serializer):
- email = serializers.EmailField()
- content = serializers.CharField(max_length=1000)
- created = serializers.DateTimeField()
- sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
-
- def restore_object(self, data, instance=None):
- if instance is None:
- return Comment(**data)
- for key, val in data.items():
- setattr(instance, key, val)
- return instance
-
-
-class NamesSerializer(serializers.Serializer):
- first = serializers.CharField()
- last = serializers.CharField(required=False, default='')
- initials = serializers.CharField(required=False, default='')
-
-
-class PersonIdentifierSerializer(serializers.Serializer):
- ssn = serializers.CharField()
- names = NamesSerializer(source='names', required=False)
-
-
-class BookSerializer(serializers.ModelSerializer):
- isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
-
- class Meta:
- model = Book
-
-
-class ActionItemSerializer(serializers.ModelSerializer):
-
- class Meta:
- model = ActionItem
-
-class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
- """
- Intended to test that fields with `required=False` are excluded from validation.
- """
- title = serializers.CharField(required=False)
-
- class Meta:
- model = ActionItem
- fields = ('title',)
-
-class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
-
- class Meta:
- model = ActionItem
-
- def restore_object(self, data, instance=None):
- if instance is None:
- return ActionItem(**data)
- for key, val in data.items():
- setattr(instance, key, val)
- return instance
-
-
-class PersonSerializer(serializers.ModelSerializer):
- info = serializers.Field(source='info')
-
- class Meta:
- model = Person
- fields = ('name', 'age', 'info')
- read_only_fields = ('age',)
-
-
-class NestedSerializer(serializers.Serializer):
- info = serializers.Field()
-
-
-class ModelSerializerWithNestedSerializer(serializers.ModelSerializer):
- nested = NestedSerializer(source='*')
-
- class Meta:
- model = Person
-
-
-class NestedSerializerWithRenamedField(serializers.Serializer):
- renamed_info = serializers.Field(source='info')
-
-
-class ModelSerializerWithNestedSerializerWithRenamedField(serializers.ModelSerializer):
- nested = NestedSerializerWithRenamedField(source='*')
-
- class Meta:
- model = Person
-
-
-class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
- """
- Testing for #652.
- """
- info = serializers.Field(source='info')
-
- class Meta:
- model = Person
- fields = ('name', 'age', 'info')
- read_only_fields = ('age', 'info')
-
-
-class AlbumsSerializer(serializers.ModelSerializer):
-
- class Meta:
- model = Album
- fields = ['title'] # lists are also valid options
-
-
-class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
- class Meta:
- model = HasPositiveIntegerAsChoice
- fields = ['some_integer']
-
-
-class BasicTests(TestCase):
- def setUp(self):
- self.comment = Comment(
- 'tom@example.com',
- 'Happy new year!',
- datetime.datetime(2012, 1, 1)
- )
- self.actionitem = ActionItem(title='Some to do item',)
- self.data = {
- 'email': 'tom@example.com',
- 'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1),
- 'sub_comment': 'This wont change'
- }
- self.expected = {
- 'email': 'tom@example.com',
- 'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1),
- 'sub_comment': 'And Merry Christmas!'
- }
- self.person_data = {'name': 'dwight', 'age': 35}
- self.person = Person(**self.person_data)
- self.person.save()
-
- def test_empty(self):
- serializer = CommentSerializer()
- expected = {
- 'email': '',
- 'content': '',
- 'created': None
- }
- self.assertEqual(serializer.data, expected)
-
- def test_retrieve(self):
- serializer = CommentSerializer(self.comment)
- self.assertEqual(serializer.data, self.expected)
-
- def test_create(self):
- serializer = CommentSerializer(data=self.data)
- expected = self.comment
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
- self.assertFalse(serializer.object is expected)
- self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
-
- def test_create_nested(self):
- """Test a serializer with nested data."""
- names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'}
- data = {'ssn': '1234567890', 'names': names}
- serializer = PersonIdentifierSerializer(data=data)
-
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
- self.assertFalse(serializer.object is data)
- self.assertEqual(serializer.data['names'], names)
-
- def test_create_partial_nested(self):
- """Test a serializer with nested data which has missing fields."""
- names = {'first': 'John'}
- data = {'ssn': '1234567890', 'names': names}
- serializer = PersonIdentifierSerializer(data=data)
-
- expected_names = {'first': 'John', 'last': '', 'initials': ''}
- data['names'] = expected_names
-
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
- self.assertFalse(serializer.object is expected_names)
- self.assertEqual(serializer.data['names'], expected_names)
-
- def test_null_nested(self):
- """Test a serializer with a nonexistent nested field"""
- data = {'ssn': '1234567890'}
- serializer = PersonIdentifierSerializer(data=data)
-
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
- self.assertFalse(serializer.object is data)
- expected = {'ssn': '1234567890', 'names': None}
- self.assertEqual(serializer.data, expected)
-
- def test_update(self):
- serializer = CommentSerializer(self.comment, data=self.data)
- expected = self.comment
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
- self.assertTrue(serializer.object is expected)
- self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
-
- def test_partial_update(self):
- msg = 'Merry New Year!'
- partial_data = {'content': msg}
- serializer = CommentSerializer(self.comment, data=partial_data)
- self.assertEqual(serializer.is_valid(), False)
- serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
- expected = self.comment
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
- self.assertTrue(serializer.object is expected)
- self.assertEqual(serializer.data['content'], msg)
-
- def test_model_fields_as_expected(self):
- """
- Make sure that the fields returned are the same as defined
- in the Meta data
- """
- serializer = PersonSerializer(self.person)
- self.assertEqual(set(serializer.data.keys()),
- set(['name', 'age', 'info']))
-
- def test_field_with_dictionary(self):
- """
- Make sure that dictionaries from fields are left intact
- """
- serializer = PersonSerializer(self.person)
- expected = self.person_data
- self.assertEqual(serializer.data['info'], expected)
-
- def test_read_only_fields(self):
- """
- Attempting to update fields set as read_only should have no effect.
- """
- serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(serializer.errors, {})
- # Assert age is unchanged (35)
- self.assertEqual(instance.age, self.person_data['age'])
-
- def test_invalid_read_only_fields(self):
- """
- Regression test for #652.
- """
- self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
-
- def test_serializer_data_is_cleared_on_save(self):
- """
- Check _data attribute is cleared on `save()`
-
- Regression test for #1116
- — id field is not populated if `data` is accessed prior to `save()`
- """
- serializer = ActionItemSerializer(self.actionitem)
- self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.')
- serializer.save()
- self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.')
-
- def test_fields_marked_as_not_required_are_excluded_from_validation(self):
- """
- Check that fields with `required=False` are included in list of exclusions.
- """
- serializer = ActionItemSerializerOptionalFields(self.actionitem)
- exclusions = serializer.get_validation_exclusions()
- self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded')
-
-
-class DictStyleSerializer(serializers.Serializer):
- """
- Note that we don't have any `restore_object` method, so the default
- case of simply returning a dict will apply.
- """
- email = serializers.EmailField()
-
-
-class DictStyleSerializerTests(TestCase):
- def test_dict_style_deserialize(self):
- """
- Ensure serializers can deserialize into a dict.
- """
- data = {'email': 'foo@example.com'}
- serializer = DictStyleSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, data)
-
- def test_dict_style_serialize(self):
- """
- Ensure serializers can serialize dict objects.
- """
- data = {'email': 'foo@example.com'}
- serializer = DictStyleSerializer(data)
- self.assertEqual(serializer.data, data)
-
-
-class ValidationTests(TestCase):
- def setUp(self):
- self.comment = Comment(
- 'tom@example.com',
- 'Happy new year!',
- datetime.datetime(2012, 1, 1)
- )
- self.data = {
- 'email': 'tom@example.com',
- 'content': 'x' * 1001,
- 'created': datetime.datetime(2012, 1, 1)
- }
- self.actionitem = ActionItem(title='Some to do item',)
-
- def test_create(self):
- serializer = CommentSerializer(data=self.data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
-
- def test_update(self):
- serializer = CommentSerializer(self.comment, data=self.data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
-
- def test_update_missing_field(self):
- data = {
- 'content': 'xxx',
- 'created': datetime.datetime(2012, 1, 1)
- }
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'email': ['This field is required.']})
-
- def test_missing_bool_with_default(self):
- """Make sure that a boolean value with a 'False' value is not
- mistaken for not having a default."""
- data = {
- 'title': 'Some action item',
- #No 'done' value.
- }
- serializer = ActionItemSerializer(self.actionitem, data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.errors, {})
-
- def test_cross_field_validation(self):
-
- class CommentSerializerWithCrossFieldValidator(CommentSerializer):
-
- def validate(self, attrs):
- if attrs["email"] not in attrs["content"]:
- raise serializers.ValidationError("Email address not in content")
- return attrs
-
- data = {
- 'email': 'tom@example.com',
- 'content': 'A comment from tom@example.com',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = CommentSerializerWithCrossFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
-
- data['content'] = 'A comment from foo@bar.com'
-
- serializer = CommentSerializerWithCrossFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']})
-
- def test_null_is_true_fields(self):
- """
- Omitting a value for null-field should validate.
- """
- serializer = PersonSerializer(data={'name': 'marko'})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.errors, {})
-
- def test_modelserializer_max_length_exceeded(self):
- data = {
- 'title': 'x' * 201,
- }
- serializer = ActionItemSerializer(data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
-
- def test_modelserializer_max_length_exceeded_with_custom_restore(self):
- """
- When overriding ModelSerializer.restore_object, validation tests should still apply.
- Regression test for #623.
-
- https://github.com/tomchristie/django-rest-framework/pull/623
- """
- data = {
- 'title': 'x' * 201,
- }
- serializer = ActionItemSerializerCustomRestore(data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
-
- def test_default_modelfield_max_length_exceeded(self):
- data = {
- 'title': 'Testing "info" field...',
- 'info': 'x' * 13,
- }
- serializer = ActionItemSerializer(data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
-
- def test_datetime_validation_failure(self):
- """
- Test DateTimeField validation errors on non-str values.
- Regression test for #669.
-
- https://github.com/tomchristie/django-rest-framework/issues/669
- """
- data = self.data
- data['created'] = 0
-
- serializer = CommentSerializer(data=data)
- self.assertEqual(serializer.is_valid(), False)
-
- self.assertIn('created', serializer.errors)
-
- def test_missing_model_field_exception_msg(self):
- """
- Assert that a meaningful exception message is outputted when the model
- field is missing (e.g. when mistyping ``model``).
- """
- class BrokenModelSerializer(serializers.ModelSerializer):
- class Meta:
- fields = ['some_field']
-
- try:
- BrokenModelSerializer()
- except AssertionError as e:
- self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option")
- except:
- self.fail('Wrong exception type thrown.')
-
- def test_writable_star_source_on_nested_serializer(self):
- """
- Assert that a nested serializer instantiated with source='*' correctly
- expands the data into the outer serializer.
- """
- serializer = ModelSerializerWithNestedSerializer(data={
- 'name': 'marko',
- 'nested': {'info': 'hi'}},
- )
- self.assertEqual(serializer.is_valid(), True)
-
- def test_writable_star_source_with_inner_source_fields(self):
- """
- Tests that a serializer with source="*" correctly expands the
- it's fields into the outer serializer even if they have their
- own 'source' parameters.
- """
-
- serializer = ModelSerializerWithNestedSerializerWithRenamedField(data={
- 'name': 'marko',
- 'nested': {'renamed_info': 'hi'}},
- )
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.errors, {})
-
-
-class CustomValidationTests(TestCase):
- class CommentSerializerWithFieldValidator(CommentSerializer):
-
- def validate_email(self, attrs, source):
- attrs[source]
- return attrs
-
- def validate_content(self, attrs, source):
- value = attrs[source]
- if "test" not in value:
- raise serializers.ValidationError("Test not in value")
- return attrs
-
- def test_field_validation(self):
- data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = self.CommentSerializerWithFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
-
- data['content'] = 'This should not validate'
-
- serializer = self.CommentSerializerWithFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'content': ['Test not in value']})
-
- def test_missing_data(self):
- """
- Make sure that validate_content isn't called if the field is missing
- """
- incomplete_data = {
- 'email': 'tom@example.com',
- 'created': datetime.datetime(2012, 1, 1)
- }
- serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'content': ['This field is required.']})
-
- def test_wrong_data(self):
- """
- Make sure that validate_content isn't called if the field input is wrong
- """
- wrong_data = {
- 'email': 'not an email',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
- serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']})
-
- def test_partial_update(self):
- """
- Make sure that validate_email isn't called when partial=True and email
- isn't found in data.
- """
- initial_data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = self.CommentSerializerWithFieldValidator(data=initial_data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.object
-
- new_content = 'An *updated* test comment'
- partial_data = {
- 'content': new_content
- }
-
- serializer = self.CommentSerializerWithFieldValidator(instance=instance,
- data=partial_data,
- partial=True)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.object
- self.assertEqual(instance.content, new_content)
-
-
-class PositiveIntegerAsChoiceTests(TestCase):
- def test_positive_integer_in_json_is_correctly_parsed(self):
- data = {'some_integer': 1}
- serializer = PositiveIntegerAsChoiceSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
-
-
-class ModelValidationTests(TestCase):
- def test_validate_unique(self):
- """
- Just check if serializers.ModelSerializer handles unique checks via .full_clean()
- """
- serializer = AlbumsSerializer(data={'title': 'a'})
- serializer.is_valid()
- serializer.save()
- second_serializer = AlbumsSerializer(data={'title': 'a'})
- self.assertFalse(second_serializer.is_valid())
- self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
-
- def test_foreign_key_is_null_with_partial(self):
- """
- Test ModelSerializer validation with partial=True
-
- Specifically test that a null foreign key does not pass validation
- """
- album = Album(title='test')
- album.save()
-
- class PhotoSerializer(serializers.ModelSerializer):
- class Meta:
- model = Photo
-
- photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
- self.assertTrue(photo_serializer.is_valid())
- photo = photo_serializer.save()
-
- # Updating only the album (foreign key)
- photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True)
- self.assertFalse(photo_serializer.is_valid())
- self.assertTrue('album' in photo_serializer.errors)
- self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required'])
-
- def test_foreign_key_with_partial(self):
- """
- Test ModelSerializer validation with partial=True
-
- Specifically test foreign key validation.
- """
-
- album = Album(title='test')
- album.save()
-
- class PhotoSerializer(serializers.ModelSerializer):
- class Meta:
- model = Photo
-
- photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
- self.assertTrue(photo_serializer.is_valid())
- photo = photo_serializer.save()
-
- # Updating only the album (foreign key)
- photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True)
- self.assertTrue(photo_serializer.is_valid())
- self.assertTrue(photo_serializer.save())
-
- # Updating only the description
- photo_serializer = PhotoSerializer(instance=photo,
- data={'description': 'new'},
- partial=True)
-
- self.assertTrue(photo_serializer.is_valid())
- self.assertTrue(photo_serializer.save())
-
-
-class RegexValidationTest(TestCase):
- def test_create_failed(self):
- serializer = BookSerializer(data={'isbn': '1234567890'})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
-
- serializer = BookSerializer(data={'isbn': '12345678901234'})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
-
- serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
- self.assertFalse(serializer.is_valid())
- self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
-
- def test_create_success(self):
- serializer = BookSerializer(data={'isbn': '1234567890123'})
- self.assertTrue(serializer.is_valid())
-
-
-class MetadataTests(TestCase):
- def test_empty(self):
- serializer = CommentSerializer()
- expected = {
- 'email': serializers.CharField,
- 'content': serializers.CharField,
- 'created': serializers.DateTimeField
- }
- for field_name, field in expected.items():
- self.assertTrue(isinstance(serializer.data.fields[field_name], field))
-
-
-class ManyToManyTests(TestCase):
- def setUp(self):
- class ManyToManySerializer(serializers.ModelSerializer):
- class Meta:
- model = ManyToManyModel
-
- self.serializer_class = ManyToManySerializer
-
- # An anchor instance to use for the relationship
- self.anchor = Anchor()
- self.anchor.save()
-
- # A model instance with a many to many relationship to the anchor
- self.instance = ManyToManyModel()
- self.instance.save()
- self.instance.rel.add(self.anchor)
-
- # A serialized representation of the model instance
- self.data = {'id': 1, 'rel': [self.anchor.id]}
-
- def test_retrieve(self):
- """
- Serialize an instance of a model with a ManyToMany relationship.
- """
- serializer = self.serializer_class(instance=self.instance)
- expected = self.data
- self.assertEqual(serializer.data, expected)
-
- def test_create(self):
- """
- Create an instance of a model with a ManyToMany relationship.
- """
- data = {'rel': [self.anchor.id]}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 2)
- self.assertEqual(instance.pk, 2)
- self.assertEqual(list(instance.rel.all()), [self.anchor])
-
- def test_update(self):
- """
- Update an instance of a model with a ManyToMany relationship.
- """
- new_anchor = Anchor()
- new_anchor.save()
- data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(self.instance, data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor])
-
- def test_create_empty_relationship(self):
- """
- Create an instance of a model with a ManyToMany relationship,
- containing no items.
- """
- data = {'rel': []}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 2)
- self.assertEqual(instance.pk, 2)
- self.assertEqual(list(instance.rel.all()), [])
-
- def test_update_empty_relationship(self):
- """
- Update an instance of a model with a ManyToMany relationship,
- containing no items.
- """
- new_anchor = Anchor()
- new_anchor.save()
- data = {'rel': []}
- serializer = self.serializer_class(self.instance, data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(list(instance.rel.all()), [])
-
- def test_create_empty_relationship_flat_data(self):
- """
- Create an instance of a model with a ManyToMany relationship,
- containing no items, using a representation that does not support
- lists (eg form data).
- """
- data = MultiValueDict()
- data.setlist('rel', [''])
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ManyToManyModel.objects.all()), 2)
- self.assertEqual(instance.pk, 2)
- self.assertEqual(list(instance.rel.all()), [])
-
-
-class ReadOnlyManyToManyTests(TestCase):
- def setUp(self):
- class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
- rel = serializers.RelatedField(many=True, read_only=True)
-
- class Meta:
- model = ReadOnlyManyToManyModel
-
- self.serializer_class = ReadOnlyManyToManySerializer
-
- # An anchor instance to use for the relationship
- self.anchor = Anchor()
- self.anchor.save()
-
- # A model instance with a many to many relationship to the anchor
- self.instance = ReadOnlyManyToManyModel()
- self.instance.save()
- self.instance.rel.add(self.anchor)
-
- # A serialized representation of the model instance
- self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
-
- def test_update(self):
- """
- Attempt to update an instance of a model with a ManyToMany
- relationship. Not updated due to read_only=True
- """
- new_anchor = Anchor()
- new_anchor.save()
- data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(self.instance, data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- # rel is still as original (1 entry)
- self.assertEqual(list(instance.rel.all()), [self.anchor])
-
- def test_update_without_relationship(self):
- """
- Attempt to update an instance of a model where many to ManyToMany
- relationship is not supplied. Not updated due to read_only=True
- """
- new_anchor = Anchor()
- new_anchor.save()
- data = {}
- serializer = self.serializer_class(self.instance, data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- # rel is still as original (1 entry)
- self.assertEqual(list(instance.rel.all()), [self.anchor])
-
-
-class DefaultValueTests(TestCase):
- def setUp(self):
- class DefaultValueSerializer(serializers.ModelSerializer):
- class Meta:
- model = DefaultValueModel
-
- self.serializer_class = DefaultValueSerializer
- self.objects = DefaultValueModel.objects
-
- def test_create_using_default(self):
- data = {}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(self.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(instance.text, 'foobar')
-
- def test_create_overriding_default(self):
- data = {'text': 'overridden'}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(self.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(instance.text, 'overridden')
-
- def test_partial_update_default(self):
- """ Regression test for issue #532 """
- data = {'text': 'overridden'}
- serializer = self.serializer_class(data=data, partial=True)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
-
- data = {'extra': 'extra_value'}
- serializer = self.serializer_class(instance=instance, data=data, partial=True)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
-
- self.assertEqual(instance.extra, 'extra_value')
- self.assertEqual(instance.text, 'overridden')
-
-
-class CallableDefaultValueTests(TestCase):
- def setUp(self):
- class CallableDefaultValueSerializer(serializers.ModelSerializer):
- class Meta:
- model = CallableDefaultValueModel
-
- self.serializer_class = CallableDefaultValueSerializer
- self.objects = CallableDefaultValueModel.objects
-
- def test_create_using_default(self):
- data = {}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(self.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(instance.text, 'foobar')
-
- def test_create_overriding_default(self):
- data = {'text': 'overridden'}
- serializer = self.serializer_class(data=data)
- self.assertEqual(serializer.is_valid(), True)
- instance = serializer.save()
- self.assertEqual(len(self.objects.all()), 1)
- self.assertEqual(instance.pk, 1)
- self.assertEqual(instance.text, 'overridden')
-
-
-class ManyRelatedTests(TestCase):
- def test_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostCommentSerializer(serializers.Serializer):
- text = serializers.CharField()
-
- class BlogPostSerializer(serializers.Serializer):
- title = serializers.CharField()
- comments = BlogPostCommentSerializer(source='blogpostcomment_set')
-
- serializer = BlogPostSerializer(instance=post)
- expected = {
- 'title': 'Test blog post',
- 'comments': [
- {'text': 'I hate this blog post'},
- {'text': 'I love this blog post'}
- ]
- }
-
- self.assertEqual(serializer.data, expected)
-
- def test_include_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlogPost
- fields = ('id', 'title', 'blogpostcomment_set')
-
- serializer = BlogPostSerializer(instance=post)
- expected = {
- 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
- }
- self.assertEqual(serializer.data, expected)
-
- def test_depth_include_reverse_relations(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlogPost
- fields = ('id', 'title', 'blogpostcomment_set')
- depth = 1
-
- serializer = BlogPostSerializer(instance=post)
- expected = {
- 'id': 1, 'title': 'Test blog post',
- 'blogpostcomment_set': [
- {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
- {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
- ]
- }
- self.assertEqual(serializer.data, expected)
-
- def test_callable_source(self):
- post = BlogPost.objects.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostCommentSerializer(serializers.Serializer):
- text = serializers.CharField()
-
- class BlogPostSerializer(serializers.Serializer):
- title = serializers.CharField()
- first_comment = BlogPostCommentSerializer(source='get_first_comment')
-
- serializer = BlogPostSerializer(post)
-
- expected = {
- 'title': 'Test blog post',
- 'first_comment': {'text': 'I love this blog post'}
- }
- self.assertEqual(serializer.data, expected)
-
-
-class RelatedTraversalTest(TestCase):
- def test_nested_traversal(self):
- """
- Source argument should support dotted.source notation.
- """
- user = Person.objects.create(name="django")
- post = BlogPost.objects.create(title="Test blog post", writer=user)
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class PersonSerializer(serializers.ModelSerializer):
- class Meta:
- model = Person
- fields = ("name", "age")
-
- class BlogPostCommentSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlogPostComment
- fields = ("text", "post_owner")
-
- text = serializers.CharField()
- post_owner = PersonSerializer(source='blog_post.writer')
-
- class BlogPostSerializer(serializers.Serializer):
- title = serializers.CharField()
- comments = BlogPostCommentSerializer(source='blogpostcomment_set')
-
- serializer = BlogPostSerializer(instance=post)
-
- expected = {
- 'title': 'Test blog post',
- 'comments': [{
- 'text': 'I love this blog post',
- 'post_owner': {
- "name": "django",
- "age": None
- }
- }]
- }
-
- self.assertEqual(serializer.data, expected)
-
- def test_nested_traversal_with_none(self):
- """
- If a component of the dotted.source is None, return None for the field.
- """
- from rest_framework.tests.models import NullableForeignKeySource
- instance = NullableForeignKeySource.objects.create(name='Source with null FK')
-
- class NullableSourceSerializer(serializers.Serializer):
- target_name = serializers.Field(source='target.name')
-
- serializer = NullableSourceSerializer(instance=instance)
-
- expected = {
- 'target_name': None,
- }
-
- self.assertEqual(serializer.data, expected)
-
-
-class SerializerMethodFieldTests(TestCase):
- def setUp(self):
-
- class BoopSerializer(serializers.Serializer):
- beep = serializers.SerializerMethodField('get_beep')
- boop = serializers.Field()
- boop_count = serializers.SerializerMethodField('get_boop_count')
-
- def get_beep(self, obj):
- return 'hello!'
-
- def get_boop_count(self, obj):
- return len(obj.boop)
-
- self.serializer_class = BoopSerializer
-
- def test_serializer_method_field(self):
-
- class MyModel(object):
- boop = ['a', 'b', 'c']
-
- source_data = MyModel()
-
- serializer = self.serializer_class(source_data)
-
- expected = {
- 'beep': 'hello!',
- 'boop': ['a', 'b', 'c'],
- 'boop_count': 3,
- }
-
- self.assertEqual(serializer.data, expected)
-
-
-# Test for issue #324
-class BlankFieldTests(TestCase):
- def setUp(self):
-
- class BlankFieldModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlankFieldModel
-
- class BlankFieldSerializer(serializers.Serializer):
- title = serializers.CharField(required=False)
-
- class NotBlankFieldModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = BasicModel
-
- class NotBlankFieldSerializer(serializers.Serializer):
- title = serializers.CharField()
-
- self.model_serializer_class = BlankFieldModelSerializer
- self.serializer_class = BlankFieldSerializer
- self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
- self.not_blank_serializer_class = NotBlankFieldSerializer
- self.data = {'title': ''}
-
- def test_create_blank_field(self):
- serializer = self.serializer_class(data=self.data)
- self.assertEqual(serializer.is_valid(), True)
-
- def test_create_model_blank_field(self):
- serializer = self.model_serializer_class(data=self.data)
- self.assertEqual(serializer.is_valid(), True)
-
- def test_create_model_null_field(self):
- serializer = self.model_serializer_class(data={'title': None})
- self.assertEqual(serializer.is_valid(), True)
-
- def test_create_not_blank_field(self):
- """
- Test to ensure blank data in a field not marked as blank=True
- is considered invalid in a non-model serializer
- """
- serializer = self.not_blank_serializer_class(data=self.data)
- self.assertEqual(serializer.is_valid(), False)
-
- def test_create_model_not_blank_field(self):
- """
- Test to ensure blank data in a field not marked as blank=True
- is considered invalid in a model serializer
- """
- serializer = self.not_blank_model_serializer_class(data=self.data)
- self.assertEqual(serializer.is_valid(), False)
-
- def test_create_model_empty_field(self):
- serializer = self.model_serializer_class(data={})
- self.assertEqual(serializer.is_valid(), True)
-
-
-#test for issue #460
-class SerializerPickleTests(TestCase):
- """
- Test pickleability of the output of Serializers
- """
- def test_pickle_simple_model_serializer_data(self):
- """
- Test simple serializer
- """
- pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data)
-
- def test_pickle_inner_serializer(self):
- """
- Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will
- have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle.
- See DictWithMetadata.__getstate__
- """
- class InnerPersonSerializer(serializers.ModelSerializer):
- class Meta:
- model = Person
- fields = ('name', 'age')
- pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0)
-
- def test_getstate_method_should_not_return_none(self):
- """
- Regression test for #645.
- """
- data = serializers.DictWithMetadata({1: 1})
- self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1}))
-
- def test_serializer_data_is_pickleable(self):
- """
- Another regression test for #645.
- """
- data = serializers.SortedDictWithMetadata({1: 1})
- repr(pickle.loads(pickle.dumps(data, 0)))
-
-
-# test for issue #725
-class SeveralChoicesModel(models.Model):
- color = models.CharField(
- max_length=10,
- choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')],
- blank=False
- )
- drink = models.CharField(
- max_length=10,
- choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')],
- blank=False,
- default='beer'
- )
- os = models.CharField(
- max_length=10,
- choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')],
- blank=True
- )
- music_genre = models.CharField(
- max_length=10,
- choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')],
- blank=True,
- default='metal'
- )
-
-
-class SerializerChoiceFields(TestCase):
-
- def setUp(self):
- super(SerializerChoiceFields, self).setUp()
-
- class SeveralChoicesSerializer(serializers.ModelSerializer):
- class Meta:
- model = SeveralChoicesModel
- fields = ('color', 'drink', 'os', 'music_genre')
-
- self.several_choices_serializer = SeveralChoicesSerializer
-
- def test_choices_blank_false_not_default(self):
- serializer = self.several_choices_serializer()
- self.assertEqual(
- serializer.fields['color'].choices,
- [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')]
- )
-
- def test_choices_blank_false_with_default(self):
- serializer = self.several_choices_serializer()
- self.assertEqual(
- serializer.fields['drink'].choices,
- [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')]
- )
-
- def test_choices_blank_true_not_default(self):
- serializer = self.several_choices_serializer()
- self.assertEqual(
- serializer.fields['os'].choices,
- BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')]
- )
-
- def test_choices_blank_true_with_default(self):
- serializer = self.several_choices_serializer()
- self.assertEqual(
- serializer.fields['music_genre'].choices,
- BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')]
- )
-
-
-# Regression tests for #675
-class Ticket(models.Model):
- assigned = models.ForeignKey(
- Person, related_name='assigned_tickets')
- reviewer = models.ForeignKey(
- Person, blank=True, null=True, related_name='reviewed_tickets')
-
-
-class SerializerRelatedChoicesTest(TestCase):
-
- def setUp(self):
- super(SerializerRelatedChoicesTest, self).setUp()
-
- class RelatedChoicesSerializer(serializers.ModelSerializer):
- class Meta:
- model = Ticket
- fields = ('assigned', 'reviewer')
-
- self.related_fields_serializer = RelatedChoicesSerializer
-
- def test_empty_queryset_required(self):
- serializer = self.related_fields_serializer()
- self.assertEqual(serializer.fields['assigned'].queryset.count(), 0)
- self.assertEqual(
- [x for x in serializer.fields['assigned'].widget.choices],
- []
- )
-
- def test_empty_queryset_not_required(self):
- serializer = self.related_fields_serializer()
- self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0)
- self.assertEqual(
- [x for x in serializer.fields['reviewer'].widget.choices],
- [('', '---------')]
- )
-
- def test_with_some_persons_required(self):
- Person.objects.create(name="Lionel Messi")
- Person.objects.create(name="Xavi Hernandez")
- serializer = self.related_fields_serializer()
- self.assertEqual(serializer.fields['assigned'].queryset.count(), 2)
- self.assertEqual(
- [x for x in serializer.fields['assigned'].widget.choices],
- [(1, 'Person object - 1'), (2, 'Person object - 2')]
- )
-
- def test_with_some_persons_not_required(self):
- Person.objects.create(name="Lionel Messi")
- Person.objects.create(name="Xavi Hernandez")
- serializer = self.related_fields_serializer()
- self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2)
- self.assertEqual(
- [x for x in serializer.fields['reviewer'].widget.choices],
- [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')]
- )
-
-
-class DepthTest(TestCase):
- def test_implicit_nesting(self):
-
- writer = Person.objects.create(name="django", age=1)
- post = BlogPost.objects.create(title="Test blog post", writer=writer)
- comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
-
- class BlogPostCommentSerializer(serializers.ModelSerializer):
- class Meta:
- model = BlogPostComment
- depth = 2
-
- serializer = BlogPostCommentSerializer(instance=comment)
- expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
-
- self.assertEqual(serializer.data, expected)
-
- def test_explicit_nesting(self):
- writer = Person.objects.create(name="django", age=1)
- post = BlogPost.objects.create(title="Test blog post", writer=writer)
- comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
-
- class PersonSerializer(serializers.ModelSerializer):
- class Meta:
- model = Person
-
- class BlogPostSerializer(serializers.ModelSerializer):
- writer = PersonSerializer()
-
- class Meta:
- model = BlogPost
-
- class BlogPostCommentSerializer(serializers.ModelSerializer):
- blog_post = BlogPostSerializer()
-
- class Meta:
- model = BlogPostComment
-
- serializer = BlogPostCommentSerializer(instance=comment)
- expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
- 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
-
- self.assertEqual(serializer.data, expected)
-
-
-class NestedSerializerContextTests(TestCase):
-
- def test_nested_serializer_context(self):
- """
- Regression for #497
-
- https://github.com/tomchristie/django-rest-framework/issues/497
- """
- class PhotoSerializer(serializers.ModelSerializer):
- class Meta:
- model = Photo
- fields = ("description", "callable")
-
- callable = serializers.SerializerMethodField('_callable')
-
- def _callable(self, instance):
- if not 'context_item' in self.context:
- raise RuntimeError("context isn't getting passed into 2nd level nested serializer")
- return "success"
-
- class AlbumSerializer(serializers.ModelSerializer):
- class Meta:
- model = Album
- fields = ("photo_set", "callable")
-
- photo_set = PhotoSerializer(source="photo_set")
- callable = serializers.SerializerMethodField("_callable")
-
- def _callable(self, instance):
- if not 'context_item' in self.context:
- raise RuntimeError("context isn't getting passed into 1st level nested serializer")
- return "success"
-
- class AlbumCollection(object):
- albums = None
-
- class AlbumCollectionSerializer(serializers.Serializer):
- albums = AlbumSerializer(source="albums")
-
- album1 = Album.objects.create(title="album 1")
- album2 = Album.objects.create(title="album 2")
- Photo.objects.create(description="Bigfoot", album=album1)
- Photo.objects.create(description="Unicorn", album=album1)
- Photo.objects.create(description="Yeti", album=album2)
- Photo.objects.create(description="Sasquatch", album=album2)
- album_collection = AlbumCollection()
- album_collection.albums = [album1, album2]
-
- # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
- AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
-
-
-class DeserializeListTestCase(TestCase):
-
- def setUp(self):
- self.data = {
- 'email': 'nobody@nowhere.com',
- 'content': 'This is some test content',
- 'created': datetime.datetime(2013, 3, 7),
- }
-
- def test_no_errors(self):
- data = [self.data.copy() for x in range(0, 3)]
- serializer = CommentSerializer(data=data, many=True)
- self.assertTrue(serializer.is_valid())
- self.assertTrue(isinstance(serializer.object, list))
- self.assertTrue(
- all((isinstance(item, Comment) for item in serializer.object))
- )
-
- def test_errors_return_as_list(self):
- invalid_item = self.data.copy()
- invalid_item['email'] = ''
- data = [self.data.copy(), invalid_item, self.data.copy()]
-
- serializer = CommentSerializer(data=data, many=True)
- self.assertFalse(serializer.is_valid())
- expected = [{}, {'email': ['This field is required.']}, {}]
- self.assertEqual(serializer.errors, expected)
-
-
-# Test for issue 747
-
-class LazyStringModel(object):
- def __init__(self, lazystring):
- self.lazystring = lazystring
-
-
-class LazyStringSerializer(serializers.Serializer):
- lazystring = serializers.Field()
-
- def restore_object(self, attrs, instance=None):
- if instance is not None:
- instance.lazystring = attrs.get('lazystring', instance.lazystring)
- return instance
- return LazyStringModel(**attrs)
-
-
-class LazyStringsTestCase(TestCase):
- def setUp(self):
- self.model = LazyStringModel(lazystring=_('lazystring'))
-
- def test_lazy_strings_are_translated(self):
- serializer = LazyStringSerializer(self.model)
- self.assertEqual(type(serializer.data['lazystring']),
- type('lazystring'))
-
-
-# Test for issue #467
-
-class FieldLabelTest(TestCase):
- def setUp(self):
- self.serializer_class = BasicModelSerializer
-
- def test_label_from_model(self):
- """
- Validates that label and help_text are correctly copied from the model class.
- """
- serializer = self.serializer_class()
- text_field = serializer.fields['text']
-
- self.assertEqual('Text comes here', text_field.label)
- self.assertEqual('Text description.', text_field.help_text)
-
- def test_field_ctor(self):
- """
- This is check that ctor supports both label and help_text.
- """
- self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label)
- self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text)
- self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label)
-
-
-# Test for issue #961
-
-class ManyFieldHelpTextTest(TestCase):
- def test_help_text_no_hold_down_control_msg(self):
- """
- Validate that help_text doesn't contain the 'Hold down "Control" ...'
- message that Django appends to choice fields.
- """
- rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text)
- self.assertEqual('Some help text.', rel_field.help_text)
-
-
-class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
-
- def setUp(self):
- class AMOAFModel(RESTFrameworkModel):
- char_field = models.CharField(max_length=1024, blank=True)
- comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
- decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
- email_field = models.EmailField(max_length=1024, blank=True)
- file_field = models.FileField(max_length=1024, blank=True)
- image_field = models.ImageField(max_length=1024, blank=True)
- slug_field = models.SlugField(max_length=1024, blank=True)
- url_field = models.URLField(max_length=1024, blank=True)
-
- class AMOAFSerializer(serializers.ModelSerializer):
- class Meta:
- model = AMOAFModel
-
- self.serializer_class = AMOAFSerializer
- self.fields_attributes = {
- 'char_field': [
- ('max_length', 1024),
- ],
- 'comma_separated_integer_field': [
- ('max_length', 1024),
- ],
- 'decimal_field': [
- ('max_digits', 64),
- ('decimal_places', 32),
- ],
- 'email_field': [
- ('max_length', 1024),
- ],
- 'file_field': [
- ('max_length', 1024),
- ],
- 'image_field': [
- ('max_length', 1024),
- ],
- 'slug_field': [
- ('max_length', 1024),
- ],
- 'url_field': [
- ('max_length', 1024),
- ],
- }
-
- def field_test(self, field):
- serializer = self.serializer_class(data={})
- self.assertEqual(serializer.is_valid(), True)
-
- for attribute in self.fields_attributes[field]:
- self.assertEqual(
- getattr(serializer.fields[field], attribute[0]),
- attribute[1]
- )
-
- def test_char_field(self):
- self.field_test('char_field')
-
- def test_comma_separated_integer_field(self):
- self.field_test('comma_separated_integer_field')
-
- def test_decimal_field(self):
- self.field_test('decimal_field')
-
- def test_email_field(self):
- self.field_test('email_field')
-
- def test_file_field(self):
- self.field_test('file_field')
-
- def test_image_field(self):
- self.field_test('image_field')
-
- def test_slug_field(self):
- self.field_test('slug_field')
-
- def test_url_field(self):
- self.field_test('url_field')
-
-
-class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
-
- def setUp(self):
- class DVOAFModel(RESTFrameworkModel):
- positive_integer_field = models.PositiveIntegerField(blank=True)
- positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
- email_field = models.EmailField(blank=True)
- file_field = models.FileField(blank=True)
- image_field = models.ImageField(blank=True)
- slug_field = models.SlugField(blank=True)
- url_field = models.URLField(blank=True)
-
- class DVOAFSerializer(serializers.ModelSerializer):
- class Meta:
- model = DVOAFModel
-
- self.serializer_class = DVOAFSerializer
- self.fields_attributes = {
- 'positive_integer_field': [
- ('min_value', 0),
- ],
- 'positive_small_integer_field': [
- ('min_value', 0),
- ],
- 'email_field': [
- ('max_length', 75),
- ],
- 'file_field': [
- ('max_length', 100),
- ],
- 'image_field': [
- ('max_length', 100),
- ],
- 'slug_field': [
- ('max_length', 50),
- ],
- 'url_field': [
- ('max_length', 200),
- ],
- }
-
- def field_test(self, field):
- serializer = self.serializer_class(data={})
- self.assertEqual(serializer.is_valid(), True)
-
- for attribute in self.fields_attributes[field]:
- self.assertEqual(
- getattr(serializer.fields[field], attribute[0]),
- attribute[1]
- )
-
- def test_positive_integer_field(self):
- self.field_test('positive_integer_field')
-
- def test_positive_small_integer_field(self):
- self.field_test('positive_small_integer_field')
-
- def test_email_field(self):
- self.field_test('email_field')
-
- def test_file_field(self):
- self.field_test('file_field')
-
- def test_image_field(self):
- self.field_test('image_field')
-
- def test_slug_field(self):
- self.field_test('slug_field')
-
- def test_url_field(self):
- self.field_test('url_field')
-
-
-class MetadataSerializer(serializers.Serializer):
- field1 = serializers.CharField(3, required=True)
- field2 = serializers.CharField(10, required=False)
-
-
-class MetadataSerializerTestCase(TestCase):
- def setUp(self):
- self.serializer = MetadataSerializer()
-
- def test_serializer_metadata(self):
- metadata = self.serializer.metadata()
- expected = {
- 'field1': {
- 'required': True,
- 'max_length': 3,
- 'type': 'string',
- 'read_only': False
- },
- 'field2': {
- 'required': False,
- 'max_length': 10,
- 'type': 'string',
- 'read_only': False
- }
- }
- self.assertEqual(expected, metadata)
-
-
-### Regression test for #840
-
-class SimpleModel(models.Model):
- text = models.CharField(max_length=100)
-
-
-class SimpleModelSerializer(serializers.ModelSerializer):
- text = serializers.CharField()
- other = serializers.CharField()
-
- class Meta:
- model = SimpleModel
-
- def validate_other(self, attrs, source):
- del attrs['other']
- return attrs
-
-
-class FieldValidationRemovingAttr(TestCase):
- def test_removing_non_model_field_in_validation(self):
- """
- Removing an attr during field valiation should ensure that it is not
- passed through when restoring the object.
-
- This allows additional non-model fields to be supported.
-
- Regression test for #840.
- """
- serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'})
- self.assertTrue(serializer.is_valid())
- serializer.save()
- self.assertEqual(serializer.object.text, 'foo')
-
-
-### Regression test for #878
-
-class SimpleTargetModel(models.Model):
- text = models.CharField(max_length=100)
-
-
-class SimplePKSourceModelSerializer(serializers.Serializer):
- targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True)
- text = serializers.CharField()
-
-
-class SimpleSlugSourceModelSerializer(serializers.Serializer):
- targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk')
- text = serializers.CharField()
-
-
-class SerializerSupportsManyRelationships(TestCase):
- def setUp(self):
- SimpleTargetModel.objects.create(text='foo')
- SimpleTargetModel.objects.create(text='bar')
-
- def test_serializer_supports_pk_many_relationships(self):
- """
- Regression test for #878.
-
- Note that pk behavior has a different code path to usual cases,
- for performance reasons.
- """
- serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
-
- def test_serializer_supports_slug_many_relationships(self):
- """
- Regression test for #878.
- """
- serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
- self.assertTrue(serializer.is_valid())
- self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
-
-
-class TransformMethodsSerializer(serializers.Serializer):
- a = serializers.CharField()
- b_renamed = serializers.CharField(source='b')
-
- def transform_a(self, obj, value):
- return value.lower()
-
- def transform_b_renamed(self, obj, value):
- if value is not None:
- return 'and ' + value
-
-
-class TestSerializerTransformMethods(TestCase):
- def setUp(self):
- self.s = TransformMethodsSerializer()
-
- def test_transform_methods(self):
- self.assertEqual(
- self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}),
- {
- 'a': 'green eggs',
- 'b_renamed': 'and HAM',
- }
- )
-
- def test_missing_fields(self):
- self.assertEqual(
- self.s.to_native({'a': 'GREEN EGGS'}),
- {
- 'a': 'green eggs',
- 'b_renamed': None,
- }
- )
-
-
-class DefaultTrueBooleanModel(models.Model):
- cat = models.BooleanField(default=True)
- dog = models.BooleanField(default=False)
-
-
-class SerializerDefaultTrueBoolean(TestCase):
-
- def setUp(self):
- super(SerializerDefaultTrueBoolean, self).setUp()
-
- class DefaultTrueBooleanSerializer(serializers.ModelSerializer):
- class Meta:
- model = DefaultTrueBooleanModel
- fields = ('cat', 'dog')
-
- self.default_true_boolean_serializer = DefaultTrueBooleanSerializer
-
- def test_enabled_as_false(self):
- serializer = self.default_true_boolean_serializer(data={'cat': False,
- 'dog': False})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data['cat'], False)
- self.assertEqual(serializer.data['dog'], False)
-
- def test_enabled_as_true(self):
- serializer = self.default_true_boolean_serializer(data={'cat': True,
- 'dog': True})
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data['cat'], True)
- self.assertEqual(serializer.data['dog'], True)
-
- def test_enabled_partial(self):
- serializer = self.default_true_boolean_serializer(data={'cat': False},
- partial=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data['cat'], False)
- self.assertEqual(serializer.data['dog'], False)
-
-
-class BoolenFieldTypeTest(TestCase):
- '''
- Ensure the various Boolean based model fields are rendered as the proper
- field type
-
- '''
-
- def setUp(self):
- '''
- Setup an ActionItemSerializer for BooleanTesting
- '''
- data = {
- 'title': 'b' * 201,
- }
- self.serializer = ActionItemSerializer(data=data)
-
- def test_booleanfield_type(self):
- '''
- Test that BooleanField is infered from models.BooleanField
- '''
- bfield = self.serializer.get_fields()['done']
- self.assertEqual(type(bfield), fields.BooleanField)
-
- def test_nullbooleanfield_type(self):
- '''
- Test that BooleanField is infered from models.NullBooleanField
-
- https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8
- '''
- bfield = self.serializer.get_fields()['started']
- self.assertEqual(type(bfield), fields.BooleanField)
diff --git a/rest_framework/tests/test_serializer_bulk_update.py b/rest_framework/tests/test_serializer_bulk_update.py
deleted file mode 100644
index 8b0ded1a..00000000
--- a/rest_framework/tests/test_serializer_bulk_update.py
+++ /dev/null
@@ -1,278 +0,0 @@
-"""
-Tests to cover bulk create and update using serializers.
-"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class BulkCreateSerializerTests(TestCase):
- """
- Creating multiple instances using serializers.
- """
-
- def setUp(self):
- class BookSerializer(serializers.Serializer):
- id = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- author = serializers.CharField(max_length=100)
-
- self.BookSerializer = BookSerializer
-
- def test_bulk_create_success(self):
- """
- Correct bulk update serialization should return the input data.
- """
-
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 2,
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
-
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
-
- def test_bulk_create_errors(self):
- """
- Correct bulk update serialization should return the input data.
- """
-
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 'foo',
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
- expected_errors = [
- {},
- {},
- {'id': ['Enter a whole number.']}
- ]
-
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_invalid_list_datatype(self):
- """
- Data containing list of incorrect data type should return errors.
- """
- data = ['foo', 'bar', 'baz']
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
-
- expected_errors = [
- {'non_field_errors': ['Invalid data']},
- {'non_field_errors': ['Invalid data']},
- {'non_field_errors': ['Invalid data']}
- ]
-
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_invalid_single_datatype(self):
- """
- Data containing a single incorrect data type should return errors.
- """
- data = 123
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
-
- expected_errors = {'non_field_errors': ['Expected a list of items.']}
-
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_invalid_single_object(self):
- """
- Data containing only a single object, instead of a list of objects
- should return errors.
- """
- data = {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }
- serializer = self.BookSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
-
- expected_errors = {'non_field_errors': ['Expected a list of items.']}
-
- self.assertEqual(serializer.errors, expected_errors)
-
-
-class BulkUpdateSerializerTests(TestCase):
- """
- Updating multiple instances using serializers.
- """
-
- def setUp(self):
- class Book(object):
- """
- A data type that can be persisted to a mock storage backend
- with `.save()` and `.delete()`.
- """
- object_map = {}
-
- def __init__(self, id, title, author):
- self.id = id
- self.title = title
- self.author = author
-
- def save(self):
- Book.object_map[self.id] = self
-
- def delete(self):
- del Book.object_map[self.id]
-
- class BookSerializer(serializers.Serializer):
- id = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- author = serializers.CharField(max_length=100)
-
- def restore_object(self, attrs, instance=None):
- if instance:
- instance.id = attrs['id']
- instance.title = attrs['title']
- instance.author = attrs['author']
- return instance
- return Book(**attrs)
-
- self.Book = Book
- self.BookSerializer = BookSerializer
-
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 2,
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
-
- for item in data:
- book = Book(item['id'], item['title'], item['author'])
- book.save()
-
- def books(self):
- """
- Return all the objects in the mock storage backend.
- """
- return self.Book.object_map.values()
-
- def test_bulk_update_success(self):
- """
- Correct bulk update serialization should return the input data.
- """
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 2,
- 'title': 'Kafka on the shore',
- 'author': 'Haruki Murakami'
- }
- ]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data, data)
- serializer.save()
- new_data = self.BookSerializer(self.books(), many=True).data
-
- self.assertEqual(data, new_data)
-
- def test_bulk_update_and_create(self):
- """
- Bulk update serialization may also include created items.
- """
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 3,
- 'title': 'Kafka on the shore',
- 'author': 'Haruki Murakami'
- }
- ]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.data, data)
- serializer.save()
- new_data = self.BookSerializer(self.books(), many=True).data
- self.assertEqual(data, new_data)
-
- def test_bulk_update_invalid_create(self):
- """
- Bulk update serialization without allow_add_remove may not create items.
- """
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 3,
- 'title': 'Kafka on the shore',
- 'author': 'Haruki Murakami'
- }
- ]
- expected_errors = [
- {},
- {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
- ]
- serializer = self.BookSerializer(self.books(), data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_bulk_update_error(self):
- """
- Incorrect bulk update serialization should return error data.
- """
- data = [
- {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 'foo',
- 'title': 'Kafka on the shore',
- 'author': 'Haruki Murakami'
- }
- ]
- expected_errors = [
- {},
- {'id': ['Enter a whole number.']}
- ]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
diff --git a/rest_framework/tests/test_serializer_empty.py b/rest_framework/tests/test_serializer_empty.py
deleted file mode 100644
index 30cff361..00000000
--- a/rest_framework/tests/test_serializer_empty.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class EmptySerializerTestCase(TestCase):
- def test_empty_serializer(self):
- class FooBarSerializer(serializers.Serializer):
- foo = serializers.IntegerField()
- bar = serializers.SerializerMethodField('get_bar')
-
- def get_bar(self, obj):
- return 'bar'
-
- serializer = FooBarSerializer()
- self.assertEquals(serializer.data, {'foo': 0})
diff --git a/rest_framework/tests/test_serializer_import.py b/rest_framework/tests/test_serializer_import.py
deleted file mode 100644
index 9f30a7ff..00000000
--- a/rest_framework/tests/test_serializer_import.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from django.test import TestCase
-
-from rest_framework import serializers
-from rest_framework.tests.accounts.serializers import AccountSerializer
-
-
-class ImportingModelSerializerTests(TestCase):
- """
- In some situations like, GH #1225, it is possible, especially in
- testing, to import a serializer who's related models have not yet
- been resolved by Django. `AccountSerializer` is an example of such
- a serializer (imported at the top of this file).
- """
- def test_import_model_serializer(self):
- """
- The serializer at the top of this file should have been
- imported successfully, and we should be able to instantiate it.
- """
- self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer)
diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py
deleted file mode 100644
index 6d69ffbd..00000000
--- a/rest_framework/tests/test_serializer_nested.py
+++ /dev/null
@@ -1,347 +0,0 @@
-"""
-Tests to cover nested serializers.
-
-Doesn't cover model serializers.
-"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import serializers
-from . import models
-
-
-class WritableNestedSerializerBasicTests(TestCase):
- """
- Tests for deserializing nested entities.
- Basic tests that use serializers that simply restore to dicts.
- """
-
- def setUp(self):
- class TrackSerializer(serializers.Serializer):
- order = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- duration = serializers.IntegerField()
-
- class AlbumSerializer(serializers.Serializer):
- album_name = serializers.CharField(max_length=100)
- artist = serializers.CharField(max_length=100)
- tracks = TrackSerializer(many=True)
-
- self.AlbumSerializer = AlbumSerializer
-
- def test_nested_validation_success(self):
- """
- Correct nested serialization should return the input data.
- """
-
- data = {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 239}
- ]
- }
-
- serializer = self.AlbumSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, data)
-
- def test_nested_validation_error(self):
- """
- Incorrect nested serialization should return appropriate error data.
- """
-
- data = {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
- ]
- }
- expected_errors = {
- 'tracks': [
- {},
- {},
- {'duration': ['Enter a whole number.']}
- ]
- }
-
- serializer = self.AlbumSerializer(data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
-
- def test_many_nested_validation_error(self):
- """
- Incorrect nested serialization should return appropriate error data
- when multiple entities are being deserialized.
- """
-
- data = [
- {
- 'album_name': 'Russian Red',
- 'artist': 'I Love Your Glasses',
- 'tracks': [
- {'order': 1, 'title': 'Cigarettes', 'duration': 121},
- {'order': 2, 'title': 'No Past Land', 'duration': 198},
- {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
- ]
- },
- {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
- ]
- }
- ]
- expected_errors = [
- {},
- {
- 'tracks': [
- {},
- {},
- {'duration': ['Enter a whole number.']}
- ]
- }
- ]
-
- serializer = self.AlbumSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, expected_errors)
-
-
-class WritableNestedSerializerObjectTests(TestCase):
- """
- Tests for deserializing nested entities.
- These tests use serializers that restore to concrete objects.
- """
-
- def setUp(self):
- # Couple of concrete objects that we're going to deserialize into
- class Track(object):
- def __init__(self, order, title, duration):
- self.order, self.title, self.duration = order, title, duration
-
- def __eq__(self, other):
- return (
- self.order == other.order and
- self.title == other.title and
- self.duration == other.duration
- )
-
- class Album(object):
- def __init__(self, album_name, artist, tracks):
- self.album_name, self.artist, self.tracks = album_name, artist, tracks
-
- def __eq__(self, other):
- return (
- self.album_name == other.album_name and
- self.artist == other.artist and
- self.tracks == other.tracks
- )
-
- # And their corresponding serializers
- class TrackSerializer(serializers.Serializer):
- order = serializers.IntegerField()
- title = serializers.CharField(max_length=100)
- duration = serializers.IntegerField()
-
- def restore_object(self, attrs, instance=None):
- return Track(attrs['order'], attrs['title'], attrs['duration'])
-
- class AlbumSerializer(serializers.Serializer):
- album_name = serializers.CharField(max_length=100)
- artist = serializers.CharField(max_length=100)
- tracks = TrackSerializer(many=True)
-
- def restore_object(self, attrs, instance=None):
- return Album(attrs['album_name'], attrs['artist'], attrs['tracks'])
-
- self.Album, self.Track = Album, Track
- self.AlbumSerializer = AlbumSerializer
-
- def test_nested_validation_success(self):
- """
- Correct nested serialization should return a restored object
- that corresponds to the input data.
- """
-
- data = {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 239}
- ]
- }
- expected_object = self.Album(
- album_name='Discovery',
- artist='Daft Punk',
- tracks=[
- self.Track(order=1, title='One More Time', duration=235),
- self.Track(order=2, title='Aerodynamic', duration=184),
- self.Track(order=3, title='Digital Love', duration=239),
- ]
- )
-
- serializer = self.AlbumSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected_object)
-
- def test_many_nested_validation_success(self):
- """
- Correct nested serialization should return multiple restored objects
- that corresponds to the input data when multiple objects are
- being deserialized.
- """
-
- data = [
- {
- 'album_name': 'Russian Red',
- 'artist': 'I Love Your Glasses',
- 'tracks': [
- {'order': 1, 'title': 'Cigarettes', 'duration': 121},
- {'order': 2, 'title': 'No Past Land', 'duration': 198},
- {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
- ]
- },
- {
- 'album_name': 'Discovery',
- 'artist': 'Daft Punk',
- 'tracks': [
- {'order': 1, 'title': 'One More Time', 'duration': 235},
- {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
- {'order': 3, 'title': 'Digital Love', 'duration': 239}
- ]
- }
- ]
- expected_object = [
- self.Album(
- album_name='Russian Red',
- artist='I Love Your Glasses',
- tracks=[
- self.Track(order=1, title='Cigarettes', duration=121),
- self.Track(order=2, title='No Past Land', duration=198),
- self.Track(order=3, title='They Don\'t Believe', duration=191),
- ]
- ),
- self.Album(
- album_name='Discovery',
- artist='Daft Punk',
- tracks=[
- self.Track(order=1, title='One More Time', duration=235),
- self.Track(order=2, title='Aerodynamic', duration=184),
- self.Track(order=3, title='Digital Love', duration=239),
- ]
- )
- ]
-
- serializer = self.AlbumSerializer(data=data, many=True)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected_object)
-
-
-class ForeignKeyNestedSerializerUpdateTests(TestCase):
- def setUp(self):
- class Artist(object):
- def __init__(self, name):
- self.name = name
-
- def __eq__(self, other):
- return self.name == other.name
-
- class Album(object):
- def __init__(self, name, artist):
- self.name, self.artist = name, artist
-
- def __eq__(self, other):
- return self.name == other.name and self.artist == other.artist
-
- class ArtistSerializer(serializers.Serializer):
- name = serializers.CharField()
-
- def restore_object(self, attrs, instance=None):
- if instance:
- instance.name = attrs['name']
- else:
- instance = Artist(attrs['name'])
- return instance
-
- class AlbumSerializer(serializers.Serializer):
- name = serializers.CharField()
- by = ArtistSerializer(source='artist')
-
- def restore_object(self, attrs, instance=None):
- if instance:
- instance.name = attrs['name']
- instance.artist = attrs['artist']
- else:
- instance = Album(attrs['name'], attrs['artist'])
- return instance
-
- self.Artist = Artist
- self.Album = Album
- self.AlbumSerializer = AlbumSerializer
-
- def test_create_via_foreign_key_with_source(self):
- """
- Check that we can both *create* and *update* into objects across
- ForeignKeys that have a `source` specified.
- Regression test for #1170
- """
- data = {
- 'name': 'Discovery',
- 'by': {'name': 'Daft Punk'},
- }
-
- expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery')
-
- # create
- serializer = self.AlbumSerializer(data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
-
- # update
- original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters')
- serializer = self.AlbumSerializer(instance=original, data=data)
- self.assertEqual(serializer.is_valid(), True)
- self.assertEqual(serializer.object, expected)
-
-
-class NestedModelSerializerUpdateTests(TestCase):
- def test_second_nested_level(self):
- john = models.Person.objects.create(name="john")
-
- post = john.blogpost_set.create(title="Test blog post")
- post.blogpostcomment_set.create(text="I hate this blog post")
- post.blogpostcomment_set.create(text="I love this blog post")
-
- class BlogPostCommentSerializer(serializers.ModelSerializer):
- class Meta:
- model = models.BlogPostComment
-
- class BlogPostSerializer(serializers.ModelSerializer):
- comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
- class Meta:
- model = models.BlogPost
- fields = ('id', 'title', 'comments')
-
- class PersonSerializer(serializers.ModelSerializer):
- posts = BlogPostSerializer(many=True, source='blogpost_set')
- class Meta:
- model = models.Person
- fields = ('id', 'name', 'age', 'posts')
-
- serialize = PersonSerializer(instance=john)
- deserialize = PersonSerializer(data=serialize.data, instance=john)
- self.assertTrue(deserialize.is_valid())
-
- result = deserialize.object
- result.save()
- self.assertEqual(result.id, john.id)
diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py
deleted file mode 100644
index 082a400c..00000000
--- a/rest_framework/tests/test_serializers.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from django.db import models
-from django.test import TestCase
-
-from rest_framework.serializers import _resolve_model
-from rest_framework.tests.models import BasicModel
-
-
-class ResolveModelTests(TestCase):
- """
- `_resolve_model` should return a Django model class given the
- provided argument is a Django model class itself, or a properly
- formatted string representation of one.
- """
- def test_resolve_django_model(self):
- resolved_model = _resolve_model(BasicModel)
- self.assertEqual(resolved_model, BasicModel)
-
- def test_resolve_string_representation(self):
- resolved_model = _resolve_model('tests.BasicModel')
- self.assertEqual(resolved_model, BasicModel)
-
- def test_resolve_non_django_model(self):
- with self.assertRaises(ValueError):
- _resolve_model(TestCase)
-
- def test_resolve_improper_string_representation(self):
- with self.assertRaises(ValueError):
- _resolve_model('BasicModel')
diff --git a/rest_framework/tests/test_settings.py b/rest_framework/tests/test_settings.py
deleted file mode 100644
index 857375c2..00000000
--- a/rest_framework/tests/test_settings.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""Tests for the settings module"""
-from __future__ import unicode_literals
-from django.test import TestCase
-
-from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
-
-
-class TestSettings(TestCase):
- """Tests relating to the api settings"""
-
- def test_non_import_errors(self):
- """Make sure other errors aren't suppressed."""
- settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
- with self.assertRaises(ValueError):
- settings.DEFAULT_MODEL_SERIALIZER_CLASS
-
- def test_import_error_message_maintained(self):
- """Make sure real import errors are captured and raised sensibly."""
- settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
- with self.assertRaises(ImportError) as cm:
- settings.DEFAULT_MODEL_SERIALIZER_CLASS
- self.assertTrue('ImportError' in str(cm.exception))
diff --git a/rest_framework/tests/test_status.py b/rest_framework/tests/test_status.py
deleted file mode 100644
index 7b1bdae3..00000000
--- a/rest_framework/tests/test_status.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.status import (
- is_informational, is_success, is_redirect, is_client_error, is_server_error
-)
-
-
-class TestStatus(TestCase):
- def test_status_categories(self):
- self.assertFalse(is_informational(99))
- self.assertTrue(is_informational(100))
- self.assertTrue(is_informational(199))
- self.assertFalse(is_informational(200))
-
- self.assertFalse(is_success(199))
- self.assertTrue(is_success(200))
- self.assertTrue(is_success(299))
- self.assertFalse(is_success(300))
-
- self.assertFalse(is_redirect(299))
- self.assertTrue(is_redirect(300))
- self.assertTrue(is_redirect(399))
- self.assertFalse(is_redirect(400))
-
- self.assertFalse(is_client_error(399))
- self.assertTrue(is_client_error(400))
- self.assertTrue(is_client_error(499))
- self.assertFalse(is_client_error(500))
-
- self.assertFalse(is_server_error(499))
- self.assertTrue(is_server_error(500))
- self.assertTrue(is_server_error(599))
- self.assertFalse(is_server_error(600))
\ No newline at end of file
diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py
deleted file mode 100644
index d4da0c23..00000000
--- a/rest_framework/tests/test_templatetags.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# encoding: utf-8
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework.test import APIRequestFactory
-from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links
-
-factory = APIRequestFactory()
-
-
-class TemplateTagTests(TestCase):
-
- def test_add_query_param_with_non_latin_charactor(self):
- # Ensure we don't double-escape non-latin characters
- # that are present in the querystring.
- # See #1314.
- request = factory.get("/", {'q': '查询'})
- json_url = add_query_param(request, "format", "json")
- self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url)
- self.assertIn("format=json", json_url)
-
-
-class Issue1386Tests(TestCase):
- """
- Covers #1386
- """
-
- def test_issue_1386(self):
- """
- Test function urlize_quoted_links with different args
- """
- correct_urls = [
- "asdf.com",
- "asdf.net",
- "www.as_df.org",
- "as.d8f.ghj8.gov",
- ]
- for i in correct_urls:
- res = urlize_quoted_links(i)
- self.assertNotEqual(res, i)
- self.assertIn(i, res)
-
- incorrect_urls = [
- "mailto://asdf@fdf.com",
- "asdf.netnet",
- ]
- for i in incorrect_urls:
- res = urlize_quoted_links(i)
- self.assertEqual(i, res)
-
- # example from issue #1386, this shouldn't raise an exception
- _ = urlize_quoted_links("asdf:[/p]zxcv.com")
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
deleted file mode 100644
index 71bd8b55..00000000
--- a/rest_framework/tests/test_testing.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# -- coding: utf-8 --
-
-from __future__ import unicode_literals
-from io import BytesIO
-
-from django.contrib.auth.models import User
-from django.test import TestCase
-from rest_framework.compat import patterns, url
-from rest_framework.decorators import api_view
-from rest_framework.response import Response
-from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
-
-
-@api_view(['GET', 'POST'])
-def view(request):
- return Response({
- 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
- 'user': request.user.username
- })
-
-
-@api_view(['GET', 'POST'])
-def session_view(request):
- active_session = request.session.get('active_session', False)
- request.session['active_session'] = True
- return Response({
- 'active_session': active_session
- })
-
-
-urlpatterns = patterns('',
- url(r'^view/$', view),
- url(r'^session-view/$', session_view),
-)
-
-
-class TestAPITestClient(TestCase):
- urls = 'rest_framework.tests.test_testing'
-
- def setUp(self):
- self.client = APIClient()
-
- def test_credentials(self):
- """
- Setting `.credentials()` adds the required headers to each request.
- """
- self.client.credentials(HTTP_AUTHORIZATION='example')
- for _ in range(0, 3):
- response = self.client.get('/view/')
- self.assertEqual(response.data['auth'], 'example')
-
- def test_force_authenticate(self):
- """
- Setting `.force_authenticate()` forcibly authenticates each request.
- """
- user = User.objects.create_user('example', 'example@example.com')
- self.client.force_authenticate(user)
- response = self.client.get('/view/')
- self.assertEqual(response.data['user'], 'example')
-
- def test_force_authenticate_with_sessions(self):
- """
- Setting `.force_authenticate()` forcibly authenticates each request.
- """
- user = User.objects.create_user('example', 'example@example.com')
- self.client.force_authenticate(user)
-
- # First request does not yet have an active session
- response = self.client.get('/session-view/')
- self.assertEqual(response.data['active_session'], False)
-
- # Subsequant requests have an active session
- response = self.client.get('/session-view/')
- self.assertEqual(response.data['active_session'], True)
-
- # Force authenticating as `None` should also logout the user session.
- self.client.force_authenticate(None)
- response = self.client.get('/session-view/')
- self.assertEqual(response.data['active_session'], False)
-
- def test_csrf_exempt_by_default(self):
- """
- By default, the test client is CSRF exempt.
- """
- User.objects.create_user('example', 'example@example.com', 'password')
- self.client.login(username='example', password='password')
- response = self.client.post('/view/')
- self.assertEqual(response.status_code, 200)
-
- def test_explicitly_enforce_csrf_checks(self):
- """
- The test client can enforce CSRF checks.
- """
- client = APIClient(enforce_csrf_checks=True)
- User.objects.create_user('example', 'example@example.com', 'password')
- client.login(username='example', password='password')
- response = client.post('/view/')
- expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
- self.assertEqual(response.status_code, 403)
- self.assertEqual(response.data, expected)
-
-
-class TestAPIRequestFactory(TestCase):
- def test_csrf_exempt_by_default(self):
- """
- By default, the test client is CSRF exempt.
- """
- user = User.objects.create_user('example', 'example@example.com', 'password')
- factory = APIRequestFactory()
- request = factory.post('/view/')
- request.user = user
- response = view(request)
- self.assertEqual(response.status_code, 200)
-
- def test_explicitly_enforce_csrf_checks(self):
- """
- The test client can enforce CSRF checks.
- """
- user = User.objects.create_user('example', 'example@example.com', 'password')
- factory = APIRequestFactory(enforce_csrf_checks=True)
- request = factory.post('/view/')
- request.user = user
- response = view(request)
- expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
- self.assertEqual(response.status_code, 403)
- self.assertEqual(response.data, expected)
-
- def test_invalid_format(self):
- """
- Attempting to use a format that is not configured will raise an
- assertion error.
- """
- factory = APIRequestFactory()
- self.assertRaises(AssertionError, factory.post,
- path='/view/', data={'example': 1}, format='xml'
- )
-
- def test_force_authenticate(self):
- """
- Setting `force_authenticate()` forcibly authenticates the request.
- """
- user = User.objects.create_user('example', 'example@example.com')
- factory = APIRequestFactory()
- request = factory.get('/view')
- force_authenticate(request, user=user)
- response = view(request)
- self.assertEqual(response.data['user'], 'example')
-
- def test_upload_file(self):
- # This is a 1x1 black png
- simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82')
- simple_png.name = 'test.png'
- factory = APIRequestFactory()
- factory.post('/', data={'image': simple_png})
diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py
deleted file mode 100644
index 41bff692..00000000
--- a/rest_framework/tests/test_throttling.py
+++ /dev/null
@@ -1,277 +0,0 @@
-"""
-Tests for the throttling implementations in the permissions module.
-"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from django.contrib.auth.models import User
-from django.core.cache import cache
-from rest_framework.test import APIRequestFactory
-from rest_framework.views import APIView
-from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
-from rest_framework.response import Response
-
-
-class User3SecRateThrottle(UserRateThrottle):
- rate = '3/sec'
- scope = 'seconds'
-
-
-class User3MinRateThrottle(UserRateThrottle):
- rate = '3/min'
- scope = 'minutes'
-
-
-class NonTimeThrottle(BaseThrottle):
- def allow_request(self, request, view):
- if not hasattr(self.__class__, 'called'):
- self.__class__.called = True
- return True
- return False
-
-
-class MockView(APIView):
- throttle_classes = (User3SecRateThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class MockView_MinuteThrottling(APIView):
- throttle_classes = (User3MinRateThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class MockView_NonTimeThrottling(APIView):
- throttle_classes = (NonTimeThrottle,)
-
- def get(self, request):
- return Response('foo')
-
-
-class ThrottlingTests(TestCase):
- def setUp(self):
- """
- Reset the cache so that no throttles will be active
- """
- cache.clear()
- self.factory = APIRequestFactory()
-
- def test_requests_are_throttled(self):
- """
- Ensure request rate is limited
- """
- request = self.factory.get('/')
- for dummy in range(4):
- response = MockView.as_view()(request)
- self.assertEqual(429, response.status_code)
-
- def set_throttle_timer(self, view, value):
- """
- Explicitly set the timer, overriding time.time()
- """
- view.throttle_classes[0].timer = lambda self: value
-
- def test_request_throttling_expires(self):
- """
- Ensure request rate is limited for a limited duration only
- """
- self.set_throttle_timer(MockView, 0)
-
- request = self.factory.get('/')
- for dummy in range(4):
- response = MockView.as_view()(request)
- self.assertEqual(429, response.status_code)
-
- # Advance the timer by one second
- self.set_throttle_timer(MockView, 1)
-
- response = MockView.as_view()(request)
- self.assertEqual(200, response.status_code)
-
- def ensure_is_throttled(self, view, expect):
- request = self.factory.get('/')
- request.user = User.objects.create(username='a')
- for dummy in range(3):
- view.as_view()(request)
- request.user = User.objects.create(username='b')
- response = view.as_view()(request)
- self.assertEqual(expect, response.status_code)
-
- def test_request_throttling_is_per_user(self):
- """
- Ensure request rate is only limited per user, not globally for
- PerUserThrottles
- """
- self.ensure_is_throttled(MockView, 200)
-
- def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
- """
- Ensure the response returns an X-Throttle field with status and next attributes
- set properly.
- """
- request = self.factory.get('/')
- for timer, expect in expected_headers:
- self.set_throttle_timer(view, timer)
- response = view.as_view()(request)
- if expect is not None:
- self.assertEqual(response['X-Throttle-Wait-Seconds'], expect)
- else:
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
-
- def test_seconds_fields(self):
- """
- Ensure for second based throttles.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView,
- ((0, None),
- (0, None),
- (0, None),
- (0, '1')
- ))
-
- def test_minutes_fields(self):
- """
- Ensure for minute based throttles.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
- ((0, None),
- (0, None),
- (0, None),
- (0, '60')
- ))
-
- def test_next_rate_remains_constant_if_followed(self):
- """
- If a client follows the recommended next request rate,
- the throttling rate should stay constant.
- """
- self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
- ((0, None),
- (20, None),
- (40, None),
- (60, None),
- (80, None)
- ))
-
- def test_non_time_throttle(self):
- """
- Ensure for second based throttles.
- """
- request = self.factory.get('/')
-
- self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
-
- response = MockView_NonTimeThrottling.as_view()(request)
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
-
- self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
-
- response = MockView_NonTimeThrottling.as_view()(request)
- self.assertFalse('X-Throttle-Wait-Seconds' in response)
-
-
-class ScopedRateThrottleTests(TestCase):
- """
- Tests for ScopedRateThrottle.
- """
-
- def setUp(self):
- class XYScopedRateThrottle(ScopedRateThrottle):
- TIMER_SECONDS = 0
- THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
- timer = lambda self: self.TIMER_SECONDS
-
- class XView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
- throttle_scope = 'x'
-
- def get(self, request):
- return Response('x')
-
- class YView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
- throttle_scope = 'y'
-
- def get(self, request):
- return Response('y')
-
- class UnscopedView(APIView):
- throttle_classes = (XYScopedRateThrottle,)
-
- def get(self, request):
- return Response('y')
-
- self.throttle_class = XYScopedRateThrottle
- self.factory = APIRequestFactory()
- self.x_view = XView.as_view()
- self.y_view = YView.as_view()
- self.unscoped_view = UnscopedView.as_view()
-
- def increment_timer(self, seconds=1):
- self.throttle_class.TIMER_SECONDS += seconds
-
- def test_scoped_rate_throttle(self):
- request = self.factory.get('/')
-
- # Should be able to hit x view 3 times per minute.
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(429, response.status_code)
-
- # Should be able to hit y view 1 time per minute.
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(429, response.status_code)
-
- # Ensure throttles properly reset by advancing the rest of the minute
- self.increment_timer(55)
-
- # Should still be able to hit x view 3 times per minute.
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.x_view(request)
- self.assertEqual(429, response.status_code)
-
- # Should still be able to hit y view 1 time per minute.
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(200, response.status_code)
-
- self.increment_timer()
- response = self.y_view(request)
- self.assertEqual(429, response.status_code)
-
- def test_unscoped_view_not_throttled(self):
- request = self.factory.get('/')
-
- for idx in range(10):
- self.increment_timer()
- response = self.unscoped_view(request)
- self.assertEqual(200, response.status_code)
diff --git a/rest_framework/tests/test_urlpatterns.py b/rest_framework/tests/test_urlpatterns.py
deleted file mode 100644
index 8132ec4c..00000000
--- a/rest_framework/tests/test_urlpatterns.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from __future__ import unicode_literals
-from collections import namedtuple
-from django.core import urlresolvers
-from django.test import TestCase
-from rest_framework.test import APIRequestFactory
-from rest_framework.compat import patterns, url, include
-from rest_framework.urlpatterns import format_suffix_patterns
-
-
-# A container class for test paths for the test case
-URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
-
-
-def dummy_view(request, *args, **kwargs):
- pass
-
-
-class FormatSuffixTests(TestCase):
- """
- Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
- """
- def _resolve_urlpatterns(self, urlpatterns, test_paths):
- factory = APIRequestFactory()
- try:
- urlpatterns = format_suffix_patterns(urlpatterns)
- except Exception:
- self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
- resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
- for test_path in test_paths:
- request = factory.get(test_path.path)
- try:
- callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
- except Exception:
- self.fail("Failed to resolve URL: %s" % request.path_info)
- self.assertEqual(callback_args, test_path.args)
- self.assertEqual(callback_kwargs, test_path.kwargs)
-
- def test_format_suffix(self):
- urlpatterns = patterns(
- '',
- url(r'^test$', dummy_view),
- )
- test_paths = [
- URLTestPath('/test', (), {}),
- URLTestPath('/test.api', (), {'format': 'api'}),
- URLTestPath('/test.asdf', (), {'format': 'asdf'}),
- ]
- self._resolve_urlpatterns(urlpatterns, test_paths)
-
- def test_default_args(self):
- urlpatterns = patterns(
- '',
- url(r'^test$', dummy_view, {'foo': 'bar'}),
- )
- test_paths = [
- URLTestPath('/test', (), {'foo': 'bar', }),
- URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
- URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
- ]
- self._resolve_urlpatterns(urlpatterns, test_paths)
-
- def test_included_urls(self):
- nested_patterns = patterns(
- '',
- url(r'^path$', dummy_view)
- )
- urlpatterns = patterns(
- '',
- url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
- )
- test_paths = [
- URLTestPath('/test/path', (), {'foo': 'bar', }),
- URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
- URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
- ]
- self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py
deleted file mode 100644
index 124c874d..00000000
--- a/rest_framework/tests/test_validation.py
+++ /dev/null
@@ -1,104 +0,0 @@
-from __future__ import unicode_literals
-from django.db import models
-from django.test import TestCase
-from rest_framework import generics, serializers, status
-from rest_framework.test import APIRequestFactory
-
-factory = APIRequestFactory()
-
-
-# Regression for #666
-
-class ValidationModel(models.Model):
- blank_validated_field = models.CharField(max_length=255)
-
-
-class ValidationModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = ValidationModel
- fields = ('blank_validated_field',)
- read_only_fields = ('blank_validated_field',)
-
-
-class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
- model = ValidationModel
- serializer_class = ValidationModelSerializer
-
-
-class TestPreSaveValidationExclusions(TestCase):
- def test_pre_save_validation_exclusions(self):
- """
- Somewhat weird test case to ensure that we don't perform model
- validation on read only fields.
- """
- obj = ValidationModel.objects.create(blank_validated_field='')
- request = factory.put('/', {}, format='json')
- view = UpdateValidationModel().as_view()
- response = view(request, pk=obj.pk).render()
- self.assertEqual(response.status_code, status.HTTP_200_OK)
-
-
-# Regression for #653
-
-class ShouldValidateModel(models.Model):
- should_validate_field = models.CharField(max_length=255)
-
-
-class ShouldValidateModelSerializer(serializers.ModelSerializer):
- renamed = serializers.CharField(source='should_validate_field', required=False)
-
- def validate_renamed(self, attrs, source):
- value = attrs[source]
- if len(value) < 3:
- raise serializers.ValidationError('Minimum 3 characters.')
- return attrs
-
- class Meta:
- model = ShouldValidateModel
- fields = ('renamed',)
-
-
-class TestPreSaveValidationExclusionsSerializer(TestCase):
- def test_renamed_fields_are_model_validated(self):
- """
- Ensure fields with 'source' applied do get still get model validation.
- """
- # We've set `required=False` on the serializer, but the model
- # does not have `blank=True`, so this serializer should not validate.
- serializer = ShouldValidateModelSerializer(data={'renamed': ''})
- self.assertEqual(serializer.is_valid(), False)
- self.assertIn('renamed', serializer.errors)
- self.assertNotIn('should_validate_field', serializer.errors)
-
-
-class TestCustomValidationMethods(TestCase):
- def test_custom_validation_method_is_executed(self):
- serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'})
- self.assertFalse(serializer.is_valid())
- self.assertIn('renamed', serializer.errors)
-
- def test_custom_validation_method_passing(self):
- serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'})
- self.assertTrue(serializer.is_valid())
-
-
-class ValidationSerializer(serializers.Serializer):
- foo = serializers.CharField()
-
- def validate_foo(self, attrs, source):
- raise serializers.ValidationError("foo invalid")
-
- def validate(self, attrs):
- raise serializers.ValidationError("serializer invalid")
-
-
-class TestAvoidValidation(TestCase):
- """
- If serializer was initialized with invalid data (None or non dict-like), it
- should avoid validation layer (validate_ and validate methods)
- """
- def test_serializer_errors_has_only_invalid_data_error(self):
- serializer = ValidationSerializer(data='invalid data')
- self.assertFalse(serializer.is_valid())
- self.assertDictEqual(serializer.errors,
- {'non_field_errors': ['Invalid data']})
diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py
deleted file mode 100644
index 65c7e50e..00000000
--- a/rest_framework/tests/test_views.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from __future__ import unicode_literals
-
-import copy
-from django.test import TestCase
-from rest_framework import status
-from rest_framework.decorators import api_view
-from rest_framework.response import Response
-from rest_framework.settings import api_settings
-from rest_framework.test import APIRequestFactory
-from rest_framework.views import APIView
-
-factory = APIRequestFactory()
-
-
-class BasicView(APIView):
- def get(self, request, *args, **kwargs):
- return Response({'method': 'GET'})
-
- def post(self, request, *args, **kwargs):
- return Response({'method': 'POST', 'data': request.DATA})
-
-
-@api_view(['GET', 'POST', 'PUT', 'PATCH'])
-def basic_view(request):
- if request.method == 'GET':
- return {'method': 'GET'}
- elif request.method == 'POST':
- return {'method': 'POST', 'data': request.DATA}
- elif request.method == 'PUT':
- return {'method': 'PUT', 'data': request.DATA}
- elif request.method == 'PATCH':
- return {'method': 'PATCH', 'data': request.DATA}
-
-
-class ErrorView(APIView):
- def get(self, request, *args, **kwargs):
- raise Exception
-
-
-@api_view(['GET'])
-def error_view(request):
- raise Exception
-
-
-def sanitise_json_error(error_dict):
- """
- Exact contents of JSON error messages depend on the installed version
- of json.
- """
- ret = copy.copy(error_dict)
- chop = len('JSON parse error - No JSON object could be decoded')
- ret['detail'] = ret['detail'][:chop]
- return ret
-
-
-class ClassBasedViewIntegrationTests(TestCase):
- def setUp(self):
- self.view = BasicView.as_view()
-
- def test_400_parse_error(self):
- request = factory.post('/', 'f00bar', content_type='application/json')
- response = self.view(request)
- expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
- }
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(sanitise_json_error(response.data), expected)
-
- def test_400_parse_error_tunneled_content(self):
- content = 'f00bar'
- content_type = 'application/json'
- form_data = {
- api_settings.FORM_CONTENT_OVERRIDE: content,
- api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
- }
- request = factory.post('/', form_data)
- response = self.view(request)
- expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
- }
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(sanitise_json_error(response.data), expected)
-
-
-class FunctionBasedViewIntegrationTests(TestCase):
- def setUp(self):
- self.view = basic_view
-
- def test_400_parse_error(self):
- request = factory.post('/', 'f00bar', content_type='application/json')
- response = self.view(request)
- expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
- }
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(sanitise_json_error(response.data), expected)
-
- def test_400_parse_error_tunneled_content(self):
- content = 'f00bar'
- content_type = 'application/json'
- form_data = {
- api_settings.FORM_CONTENT_OVERRIDE: content,
- api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
- }
- request = factory.post('/', form_data)
- response = self.view(request)
- expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
- }
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(sanitise_json_error(response.data), expected)
-
-
-class TestCustomExceptionHandler(TestCase):
- def setUp(self):
- self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
-
- def exception_handler(exc):
- return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
-
- api_settings.EXCEPTION_HANDLER = exception_handler
-
- def tearDown(self):
- api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
-
- def test_class_based_view_exception_handler(self):
- view = ErrorView.as_view()
-
- request = factory.get('/', content_type='application/json')
- response = view(request)
- expected = 'Error!'
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(response.data, expected)
-
- def test_function_based_view_exception_handler(self):
- view = error_view
-
- request = factory.get('/', content_type='application/json')
- response = view(request)
- expected = 'Error!'
- self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertEqual(response.data, expected)
diff --git a/rest_framework/tests/test_write_only_fields.py b/rest_framework/tests/test_write_only_fields.py
deleted file mode 100644
index aabb18d6..00000000
--- a/rest_framework/tests/test_write_only_fields.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from django.db import models
-from django.test import TestCase
-from rest_framework import serializers
-
-
-class ExampleModel(models.Model):
- email = models.EmailField(max_length=100)
- password = models.CharField(max_length=100)
-
-
-class WriteOnlyFieldTests(TestCase):
- def test_write_only_fields(self):
- class ExampleSerializer(serializers.Serializer):
- email = serializers.EmailField()
- password = serializers.CharField(write_only=True)
-
- data = {
- 'email': 'foo@example.com',
- 'password': '123'
- }
- serializer = ExampleSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- self.assertEquals(serializer.object, data)
- self.assertEquals(serializer.data, {'email': 'foo@example.com'})
-
- def test_write_only_fields_meta(self):
- class ExampleSerializer(serializers.ModelSerializer):
- class Meta:
- model = ExampleModel
- fields = ('email', 'password')
- write_only_fields = ('password',)
-
- data = {
- 'email': 'foo@example.com',
- 'password': '123'
- }
- serializer = ExampleSerializer(data=data)
- self.assertTrue(serializer.is_valid())
- self.assertTrue(isinstance(serializer.object, ExampleModel))
- self.assertEquals(serializer.object.email, data['email'])
- self.assertEquals(serializer.object.password, data['password'])
- self.assertEquals(serializer.data, {'email': 'foo@example.com'})
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
deleted file mode 100644
index 554ebd1a..00000000
--- a/rest_framework/tests/tests.py
+++ /dev/null
@@ -1,16 +0,0 @@
-"""
-Force import of all modules in this package in order to get the standard test
-runner to pick up the tests. Yowzers.
-"""
-from __future__ import unicode_literals
-import os
-import django
-
-modules = [filename.rsplit('.', 1)[0]
- for filename in os.listdir(os.path.dirname(__file__))
- if filename.endswith('.py') and not filename.startswith('_')]
-__test__ = dict()
-
-if django.VERSION < (1, 6):
- for module in modules:
- exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/users/__init__.py b/rest_framework/tests/users/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/rest_framework/tests/users/models.py b/rest_framework/tests/users/models.py
deleted file mode 100644
index 128bac90..00000000
--- a/rest_framework/tests/users/models.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from django.db import models
-
-
-class User(models.Model):
- account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users')
- active_record = models.ForeignKey('records.Record', blank=True, null=True)
diff --git a/rest_framework/tests/users/serializers.py b/rest_framework/tests/users/serializers.py
deleted file mode 100644
index da496554..00000000
--- a/rest_framework/tests/users/serializers.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from rest_framework import serializers
-
-from rest_framework.tests.users.models import User
-
-
-class UserSerializer(serializers.ModelSerializer):
- class Meta:
- model = User
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
deleted file mode 100644
index 3917b74a..00000000
--- a/rest_framework/tests/views.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from rest_framework import generics
-from rest_framework.tests.models import NullableForeignKeySource
-from rest_framework.tests.serializers import NullableFKSourceSerializer
-
-
-class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
- model = NullableForeignKeySource
- model_serializer_class = NullableFKSourceSerializer
diff --git a/setup.py b/setup.py
index 78cdb628..2c56cd75 100755
--- a/setup.py
+++ b/setup.py
@@ -2,11 +2,26 @@
# -*- coding: utf-8 -*-
from setuptools import setup
+from setuptools.command.test import test as TestCommand
import re
import os
import sys
+# This command has been borrowed from
+# https://github.com/getsentry/sentry/blob/master/setup.py
+class PyTest(TestCommand):
+ def finalize_options(self):
+ TestCommand.finalize_options(self)
+ self.test_args = ['tests']
+ self.test_suite = True
+
+ def run_tests(self):
+ import pytest
+ errno = pytest.main(self.test_args)
+ sys.exit(errno)
+
+
def get_version(package):
"""
Return package version as listed in `__version__` in `init.py`.
@@ -62,7 +77,7 @@ setup(
author_email='tom@tomchristie.com', # SEE NOTE BELOW (*)
packages=get_packages('rest_framework'),
package_data=get_package_data('rest_framework'),
- test_suite='rest_framework.runtests.runtests.main',
+ cmdclass={'test': PyTest},
install_requires=[],
classifiers=[
'Development Status :: 5 - Production/Stable',
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/accounts/__init__.py b/tests/accounts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/accounts/models.py b/tests/accounts/models.py
new file mode 100644
index 00000000..3bf4a0c3
--- /dev/null
+++ b/tests/accounts/models.py
@@ -0,0 +1,8 @@
+from django.db import models
+
+from tests.users.models import User
+
+
+class Account(models.Model):
+ owner = models.ForeignKey(User, related_name='accounts_owned')
+ admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered')
diff --git a/tests/accounts/serializers.py b/tests/accounts/serializers.py
new file mode 100644
index 00000000..57a91b92
--- /dev/null
+++ b/tests/accounts/serializers.py
@@ -0,0 +1,11 @@
+from rest_framework import serializers
+
+from tests.accounts.models import Account
+from tests.users.serializers import UserSerializer
+
+
+class AccountSerializer(serializers.ModelSerializer):
+ admins = UserSerializer(many=True)
+
+ class Meta:
+ model = Account
diff --git a/tests/description.py b/tests/description.py
new file mode 100644
index 00000000..b46d7f54
--- /dev/null
+++ b/tests/description.py
@@ -0,0 +1,26 @@
+# -- coding: utf-8 --
+
+# Apparently there is a python 2.6 issue where docstrings of imported view classes
+# do not retain their encoding information even if a module has a proper
+# encoding declaration at the top of its source file. Therefore for tests
+# to catch unicode related errors, a mock view has to be declared in a separate
+# module.
+
+from rest_framework.views import APIView
+
+
+# test strings snatched from http://www.columbia.edu/~fdc/utf8/,
+# http://winrus.com/utf8-jap.htm and memory
+UTF8_TEST_DOCSTRING = (
+ 'zażółć gęślą jaźń'
+ 'Sîne klâwen durh die wolken sint geslagen'
+ 'Τη γλώσσα μου έδωσαν ελληνική'
+ 'யாமறிந்த மொழிகளிலே தமிழ்மொழி'
+ 'На берегу пустынных волн'
+ 'てすと'
+ 'アイウエオカキクケコサシスセソタチツテ'
+)
+
+
+class ViewWithNonASCIICharactersInDocstring(APIView):
+ __doc__ = UTF8_TEST_DOCSTRING
diff --git a/tests/extras/__init__.py b/tests/extras/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/extras/bad_import.py b/tests/extras/bad_import.py
new file mode 100644
index 00000000..68263d94
--- /dev/null
+++ b/tests/extras/bad_import.py
@@ -0,0 +1 @@
+raise ValueError
diff --git a/tests/models.py b/tests/models.py
new file mode 100644
index 00000000..32a726c0
--- /dev/null
+++ b/tests/models.py
@@ -0,0 +1,170 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import serializers
+
+
+def foobar():
+ return 'foobar'
+
+
+class CustomField(models.CharField):
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = 12
+ super(CustomField, self).__init__(*args, **kwargs)
+
+
+class RESTFrameworkModel(models.Model):
+ """
+ Base for test models that sets app_label, so they play nicely.
+ """
+ class Meta:
+ app_label = 'tests'
+ abstract = True
+
+
+class HasPositiveIntegerAsChoice(RESTFrameworkModel):
+ some_choices = ((1, 'A'), (2, 'B'), (3, 'C'))
+ some_integer = models.PositiveIntegerField(choices=some_choices)
+
+
+class Anchor(RESTFrameworkModel):
+ text = models.CharField(max_length=100, default='anchor')
+
+
+class BasicModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description."))
+
+
+class SlugBasedModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ slug = models.SlugField(max_length=32)
+
+
+class DefaultValueModel(RESTFrameworkModel):
+ text = models.CharField(default='foobar', max_length=100)
+ extra = models.CharField(blank=True, null=True, max_length=100)
+
+
+class CallableDefaultValueModel(RESTFrameworkModel):
+ text = models.CharField(default=foobar, max_length=100)
+
+
+class ManyToManyModel(RESTFrameworkModel):
+ rel = models.ManyToManyField(Anchor, help_text='Some help text.')
+
+
+class ReadOnlyManyToManyModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100, default='anchor')
+ rel = models.ManyToManyField(Anchor)
+
+
+# Model for regression test for #285
+
+class Comment(RESTFrameworkModel):
+ email = models.EmailField()
+ content = models.CharField(max_length=200)
+ created = models.DateTimeField(auto_now_add=True)
+
+
+class ActionItem(RESTFrameworkModel):
+ title = models.CharField(max_length=200)
+ started = models.NullBooleanField(default=False)
+ done = models.BooleanField(default=False)
+ info = CustomField(default='---', max_length=12)
+
+
+# Models for reverse relations
+class Person(RESTFrameworkModel):
+ name = models.CharField(max_length=10)
+ age = models.IntegerField(null=True, blank=True)
+
+ @property
+ def info(self):
+ return {
+ 'name': self.name,
+ 'age': self.age,
+ }
+
+
+class BlogPost(RESTFrameworkModel):
+ title = models.CharField(max_length=100)
+ writer = models.ForeignKey(Person, null=True, blank=True)
+
+ def get_first_comment(self):
+ return self.blogpostcomment_set.all()[0]
+
+
+class BlogPostComment(RESTFrameworkModel):
+ text = models.TextField()
+ blog_post = models.ForeignKey(BlogPost)
+
+
+class Album(RESTFrameworkModel):
+ title = models.CharField(max_length=100, unique=True)
+
+
+class Photo(RESTFrameworkModel):
+ description = models.TextField()
+ album = models.ForeignKey(Album)
+
+
+# Model for issue #324
+class BlankFieldModel(RESTFrameworkModel):
+ title = models.CharField(max_length=100, blank=True, null=False)
+
+
+# Model for issue #380
+class OptionalRelationModel(RESTFrameworkModel):
+ other = models.ForeignKey('OptionalRelationModel', blank=True, null=True)
+
+
+# Model for RegexField
+class Book(RESTFrameworkModel):
+ isbn = models.CharField(max_length=13)
+
+
+# Models for relations tests
+# ManyToMany
+class ManyToManyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+
+
+# ForeignKey
+class ForeignKeyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
+ related_name='nullable_sources')
+
+
+# OneToOne
+class OneToOneTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class NullableOneToOneSource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
+ related_name='nullable_source')
+
+
+# Serializer used to test BasicModel
+class BasicModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
diff --git a/tests/records/__init__.py b/tests/records/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/records/models.py b/tests/records/models.py
new file mode 100644
index 00000000..76954807
--- /dev/null
+++ b/tests/records/models.py
@@ -0,0 +1,6 @@
+from django.db import models
+
+
+class Record(models.Model):
+ account = models.ForeignKey('accounts.Account', blank=True, null=True)
+ owner = models.ForeignKey('users.User', blank=True, null=True)
diff --git a/tests/serializers.py b/tests/serializers.py
new file mode 100644
index 00000000..f2f85b6e
--- /dev/null
+++ b/tests/serializers.py
@@ -0,0 +1,8 @@
+from rest_framework import serializers
+
+from tests.models import NullableForeignKeySource
+
+
+class NullableFKSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
diff --git a/tests/settings.py b/tests/settings.py
new file mode 100644
index 00000000..75f7c54b
--- /dev/null
+++ b/tests/settings.py
@@ -0,0 +1,169 @@
+# Django settings for testproject project.
+
+DEBUG = True
+TEMPLATE_DEBUG = DEBUG
+DEBUG_PROPAGATE_EXCEPTIONS = True
+
+ALLOWED_HOSTS = ['*']
+
+ADMINS = (
+ # ('Your Name', 'your_email@domain.com'),
+)
+
+MANAGERS = ADMINS
+
+DATABASES = {
+ 'default': {
+ 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'.
+ 'NAME': 'sqlite.db', # Or path to database file if using sqlite3.
+ 'USER': '', # Not used with sqlite3.
+ 'PASSWORD': '', # Not used with sqlite3.
+ 'HOST': '', # Set to empty string for localhost. Not used with sqlite3.
+ 'PORT': '', # Set to empty string for default. Not used with sqlite3.
+ }
+}
+
+CACHES = {
+ 'default': {
+ 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
+ }
+}
+
+# Local time zone for this installation. Choices can be found here:
+# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name
+# although not all choices may be available on all operating systems.
+# On Unix systems, a value of None will cause Django to use the same
+# timezone as the operating system.
+# If running in a Windows environment this must be set to the same as your
+# system time zone.
+TIME_ZONE = 'Europe/London'
+
+# Language code for this installation. All choices can be found here:
+# http://www.i18nguy.com/unicode/language-identifiers.html
+LANGUAGE_CODE = 'en-uk'
+
+SITE_ID = 1
+
+# If you set this to False, Django will make some optimizations so as not
+# to load the internationalization machinery.
+USE_I18N = True
+
+# If you set this to False, Django will not format dates, numbers and
+# calendars according to the current locale
+USE_L10N = True
+
+# Absolute filesystem path to the directory that will hold user-uploaded files.
+# Example: "/home/media/media.lawrence.com/"
+MEDIA_ROOT = ''
+
+# URL that handles the media served from MEDIA_ROOT. Make sure to use a
+# trailing slash if there is a path component (optional in other cases).
+# Examples: "http://media.lawrence.com", "http://example.com/media/"
+MEDIA_URL = ''
+
+# Make this unique, and don't share it with anybody.
+SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'
+
+# List of callables that know how to import templates from various sources.
+TEMPLATE_LOADERS = (
+ 'django.template.loaders.filesystem.Loader',
+ 'django.template.loaders.app_directories.Loader',
+# 'django.template.loaders.eggs.Loader',
+)
+
+MIDDLEWARE_CLASSES = (
+ 'django.middleware.common.CommonMiddleware',
+ 'django.contrib.sessions.middleware.SessionMiddleware',
+ 'django.middleware.csrf.CsrfViewMiddleware',
+ 'django.contrib.auth.middleware.AuthenticationMiddleware',
+ 'django.contrib.messages.middleware.MessageMiddleware',
+)
+
+ROOT_URLCONF = 'tests.urls'
+
+TEMPLATE_DIRS = (
+ # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates".
+ # Always use forward slashes, even on Windows.
+ # Don't forget to use absolute paths, not relative paths.
+)
+
+INSTALLED_APPS = (
+ 'django.contrib.auth',
+ 'django.contrib.contenttypes',
+ 'django.contrib.sessions',
+ 'django.contrib.sites',
+ 'django.contrib.messages',
+ # Uncomment the next line to enable the admin:
+ # 'django.contrib.admin',
+ # Uncomment the next line to enable admin documentation:
+ # 'django.contrib.admindocs',
+ 'rest_framework',
+ 'rest_framework.authtoken',
+ 'tests',
+ 'tests.accounts',
+ 'tests.records',
+ 'tests.users',
+)
+
+# OAuth is optional and won't work if there is no oauth_provider & oauth2
+try:
+ import oauth_provider
+ import oauth2
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'oauth_provider',
+ )
+
+try:
+ import provider
+except ImportError:
+ pass
+else:
+ INSTALLED_APPS += (
+ 'provider',
+ 'provider.oauth2',
+ )
+
+# guardian is optional
+try:
+ import guardian
+except ImportError:
+ pass
+else:
+ ANONYMOUS_USER_ID = -1
+ AUTHENTICATION_BACKENDS = (
+ 'django.contrib.auth.backends.ModelBackend', # default
+ 'guardian.backends.ObjectPermissionBackend',
+ )
+ INSTALLED_APPS += (
+ 'guardian',
+ )
+
+STATIC_URL = '/static/'
+
+PASSWORD_HASHERS = (
+ 'django.contrib.auth.hashers.SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.BCryptPasswordHasher',
+ 'django.contrib.auth.hashers.MD5PasswordHasher',
+ 'django.contrib.auth.hashers.CryptPasswordHasher',
+)
+
+AUTH_USER_MODEL = 'auth.User'
+
+import django
+
+if django.VERSION < (1, 3):
+ INSTALLED_APPS += ('staticfiles',)
+
+
+# If we're running on the Jenkins server we want to archive the coverage reports as XML.
+import os
+if os.environ.get('HUDSON_URL', None):
+ TEST_RUNNER = 'xmlrunner.extra.djangotestrunner.XMLTestRunner'
+ TEST_OUTPUT_VERBOSE = True
+ TEST_OUTPUT_DESCRIPTIONS = True
+ TEST_OUTPUT_DIR = 'xmlrunner'
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
new file mode 100644
index 00000000..4ecfef44
--- /dev/null
+++ b/tests/test_authentication.py
@@ -0,0 +1,637 @@
+from __future__ import unicode_literals
+from django.contrib.auth.models import User
+from django.http import HttpResponse
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import HTTP_HEADER_ENCODING
+from rest_framework import exceptions
+from rest_framework import permissions
+from rest_framework import renderers
+from rest_framework.response import Response
+from rest_framework import status
+from rest_framework.authentication import (
+ BaseAuthentication,
+ TokenAuthentication,
+ BasicAuthentication,
+ SessionAuthentication,
+ OAuthAuthentication,
+ OAuth2Authentication
+)
+from rest_framework.authtoken.models import Token
+from rest_framework.compat import patterns, url, include
+from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
+from rest_framework.compat import oauth, oauth_provider
+from rest_framework.test import APIRequestFactory, APIClient
+from rest_framework.views import APIView
+import base64
+import time
+import datetime
+
+factory = APIRequestFactory()
+
+
+class MockView(APIView):
+ permission_classes = (permissions.IsAuthenticated,)
+
+ def get(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+ def post(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+ def put(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+
+urlpatterns = patterns('',
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
+ (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
+ (r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication],
+ permission_classes=[permissions.TokenHasReadWriteScope]))
+)
+
+if oauth2_provider is not None:
+ urlpatterns += patterns('',
+ url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
+ url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
+ permission_classes=[permissions.TokenHasReadWriteScope])),
+ )
+
+
+class BasicAuthTests(TestCase):
+ """Basic authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def test_post_form_passing_basic_auth(self):
+ """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_json_passing_basic_auth(self):
+ """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_form_failing_basic_auth(self):
+ """Ensure POSTing form over basic auth without correct credentials fails"""
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_post_json_failing_basic_auth(self):
+ """Ensure POSTing json over basic auth without correct credentials fails"""
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
+
+
+class SessionAuthTests(TestCase):
+ """User session authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.non_csrf_client = APIClient(enforce_csrf_checks=False)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def tearDown(self):
+ self.csrf_client.logout()
+
+ def test_post_form_session_auth_failing_csrf(self):
+ """
+ Ensure POSTing form over session authentication without CSRF token fails.
+ """
+ self.csrf_client.login(username=self.username, password=self.password)
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_post_form_session_auth_passing(self):
+ """
+ Ensure POSTing form over session authentication with logged in user and CSRF token passes.
+ """
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_put_form_session_auth_passing(self):
+ """
+ Ensure PUTting form over session authentication with logged in user and CSRF token passes.
+ """
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_form_session_auth_failing(self):
+ """
+ Ensure POSTing form over session authentication without logged in user fails.
+ """
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+
+class TokenAuthTests(TestCase):
+ """Token authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.key = 'abcd1234'
+ self.token = Token.objects.create(key=self.key, user=self.user)
+
+ def test_post_form_passing_token_auth(self):
+ """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
+ auth = 'Token ' + self.key
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_json_passing_token_auth(self):
+ """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
+ auth = "Token " + self.key
+ response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_form_failing_token_auth(self):
+ """Ensure POSTing form over token auth without correct credentials fails"""
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_post_json_failing_token_auth(self):
+ """Ensure POSTing json over token auth without correct credentials fails"""
+ response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_token_has_auto_assigned_key_if_none_provided(self):
+ """Ensure creating a token with no key will auto-assign a key"""
+ self.token.delete()
+ token = Token.objects.create(user=self.user)
+ self.assertTrue(bool(token.key))
+
+ def test_token_login_json(self):
+ """Ensure token login view using JSON POST works."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': self.password}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['token'], self.key)
+
+ def test_token_login_json_bad_creds(self):
+ """Ensure token login view using JSON POST fails if bad credentials are used."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': "badpass"}, format='json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_json_missing_fields(self):
+ """Ensure token login view using JSON POST fails if missing fields."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username}, format='json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_form(self):
+ """Ensure token login view using form POST works."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['token'], self.key)
+
+
+class IncorrectCredentialsTests(TestCase):
+ def test_incorrect_credentials(self):
+ """
+ If a request contains bad authentication credentials, then
+ authentication should run and error, even if no permissions
+ are set on the view.
+ """
+ class IncorrectCredentialsAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('Bad credentials')
+
+ request = factory.get('/')
+ view = MockView.as_view(
+ authentication_classes=(IncorrectCredentialsAuth,),
+ permission_classes=()
+ )
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.data, {'detail': 'Bad credentials'})
+
+
+class OAuthTests(TestCase):
+ """OAuth 1.0a authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ # these imports are here because oauth is optional and hiding them in try..except block or compat
+ # could obscure problems if something breaks
+ from oauth_provider.models import Consumer, Scope
+ from oauth_provider.models import Token as OAuthToken
+ from oauth_provider import consts
+
+ self.consts = consts
+
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CONSUMER_KEY = 'consumer_key'
+ self.CONSUMER_SECRET = 'consumer_secret'
+ self.TOKEN_KEY = "token_key"
+ self.TOKEN_SECRET = "token_secret"
+
+ self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
+ name='example', user=self.user, status=self.consts.ACCEPTED)
+
+ self.scope = Scope.objects.create(name="resource name", url="api/")
+ self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope,
+ token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True
+ )
+
+ def _create_authorization_header(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+
+ return req.to_header()["Authorization"]
+
+ def _create_authorization_url_parameters(self):
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="GET", url="http://example.com", parameters=params)
+
+ signature_method = oauth.SignatureMethod_PLAINTEXT()
+ req.sign_request(signature_method, self.consumer, self.token)
+ return dict(req)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_passing_oauth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_repeated_nonce_failing_oauth(self):
+ """Ensure POSTing form over OAuth with repeated auth (same nonces and timestamp) credentials fails"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ # simulate reply attack auth header containes already used (nonce, timestamp) pair
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_token_removed_failing_oauth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_consumer_status_not_accepted_failing_oauth(self):
+ """Ensure POSTing when consumer status is anything other than ACCEPTED fails"""
+ for consumer_status in (self.consts.CANCELED, self.consts.PENDING, self.consts.REJECTED):
+ self.consumer.status = consumer_status
+ self.consumer.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_request_token_failing_oauth(self):
+ """Ensure POSTing with unauthorized request token instead of access token fails"""
+ self.token.token_type = self.token.REQUEST
+ self.token.save()
+
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_urlencoded_parameters(self):
+ """Ensure POSTing with x-www-form-urlencoded auth parameters passes"""
+ params = self._create_authorization_url_parameters()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_url_parameters(self):
+ """Ensure GETing with auth in url parameters passes"""
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_hmac_sha1_signature_passes(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_get_form_with_readonly_resource_passing_auth(self):
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.scope.is_readonly = True
+ read_only_access_token.scope.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.get('/oauth-with-scope/', params)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_readonly_resource_failing_auth(self):
+ """Ensure POSTing with a readonly resource instead of a write scope fails"""
+ read_only_access_token = self.token
+ read_only_access_token.scope.is_readonly = True
+ read_only_access_token.scope.save()
+ params = self._create_authorization_url_parameters()
+ response = self.csrf_client.post('/oauth-with-scope/', params)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_post_form_with_write_resource_passing_auth(self):
+ """Ensure POSTing with a write resource succeed"""
+ read_write_access_token = self.token
+ read_write_access_token.scope.is_readonly = False
+ read_write_access_token.scope.save()
+ params = self._create_authorization_url_parameters()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth-with-scope/', params, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_consumer_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': self.token.key,
+ 'oauth_consumer_key': 'badconsumerkey'
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
+ @unittest.skipUnless(oauth, 'oauth2 not installed')
+ def test_bad_token_key(self):
+ """Ensure POSTing using HMAC_SHA1 signature method passes"""
+ params = {
+ 'oauth_version': "1.0",
+ 'oauth_nonce': oauth.generate_nonce(),
+ 'oauth_timestamp': int(time.time()),
+ 'oauth_token': 'badtokenkey',
+ 'oauth_consumer_key': self.consumer.key
+ }
+
+ req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
+
+ signature_method = oauth.SignatureMethod_HMAC_SHA1()
+ req.sign_request(signature_method, self.consumer, self.token)
+ auth = req.to_header()["Authorization"]
+
+ response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+
+class OAuth2Tests(TestCase):
+ """OAuth 2.0 authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.CLIENT_ID = 'client_key'
+ self.CLIENT_SECRET = 'client_secret'
+ self.ACCESS_TOKEN = "access_token"
+ self.REFRESH_TOKEN = "refresh_token"
+
+ self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ redirect_uri='',
+ client_type=0,
+ name='example',
+ user=None,
+ )
+
+ self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ token=self.ACCESS_TOKEN,
+ client=self.oauth2_client,
+ user=self.user,
+ )
+ self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ user=self.user,
+ access_token=self.access_token,
+ client=self.oauth2_client
+ )
+
+ def _create_authorization_header(self, token=None):
+ return "Bearer {0}".format(token or self.access_token.token)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_type_failing(self):
+ """Ensure that a wrong token type lead to the correct HTTP error status code"""
+ auth = "Wrong token-type-obsviously"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_format_failing(self):
+ """Ensure that a wrong token format lead to the correct HTTP error status code"""
+ auth = "Bearer wrong token format"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_with_wrong_authorization_header_token_failing(self):
+ """Ensure that a wrong token lead to the correct HTTP error status code"""
+ auth = "Bearer wrong-token"
+ response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 401)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth(self):
+ """Ensure GETing form over OAuth with correct client credentials succeed"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth(self):
+ """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_token_removed_failing_auth(self):
+ """Ensure POSTing when there is no OAuth access token in db fails"""
+ self.access_token.delete()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_refresh_token_failing_auth(self):
+ """Ensure POSTing with refresh token instead of access token fails"""
+ auth = self._create_authorization_header(token=self.refresh_token.token)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_expired_access_token_failing_auth(self):
+ """Ensure POSTing with expired access token fails with an 'Invalid token' error"""
+ self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
+ self.access_token.save()
+ auth = self._create_authorization_header()
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+ self.assertIn('Invalid token', response.content)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_invalid_scope_failing_auth(self):
+ """Ensure POSTing with a readonly scope instead of a write scope fails"""
+ read_only_access_token = self.access_token
+ read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
+ read_only_access_token.save()
+ auth = self._create_authorization_header(token=read_only_access_token.token)
+ response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_with_valid_scope_passing_auth(self):
+ """Ensure POSTing with a write scope succeed"""
+ read_write_access_token = self.access_token
+ read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
+ read_write_access_token.save()
+ auth = self._create_authorization_header(token=read_write_access_token.token)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+
+class FailingAuthAccessedInRenderer(TestCase):
+ def setUp(self):
+ class AuthAccessingRenderer(renderers.BaseRenderer):
+ media_type = 'text/plain'
+ format = 'txt'
+
+ def render(self, data, media_type=None, renderer_context=None):
+ request = renderer_context['request']
+ if request.user.is_authenticated():
+ return b'authenticated'
+ return b'not authenticated'
+
+ class FailingAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('authentication failed')
+
+ class ExampleView(APIView):
+ authentication_classes = (FailingAuth,)
+ renderer_classes = (AuthAccessingRenderer,)
+
+ def get(self, request):
+ return Response({'foo': 'bar'})
+
+ self.view = ExampleView.as_view()
+
+ def test_failing_auth_accessed_in_renderer(self):
+ """
+ When authentication fails the renderer should still be able to access
+ `request.user` without raising an exception. Particularly relevant
+ to HTML responses that might reasonably access `request.user`.
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ content = response.render().content
+ self.assertEqual(content, b'not authenticated')
diff --git a/tests/test_breadcrumbs.py b/tests/test_breadcrumbs.py
new file mode 100644
index 00000000..78edc603
--- /dev/null
+++ b/tests/test_breadcrumbs.py
@@ -0,0 +1,73 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.compat import patterns, url
+from rest_framework.utils.breadcrumbs import get_breadcrumbs
+from rest_framework.views import APIView
+
+
+class Root(APIView):
+ pass
+
+
+class ResourceRoot(APIView):
+ pass
+
+
+class ResourceInstance(APIView):
+ pass
+
+
+class NestedResourceRoot(APIView):
+ pass
+
+
+class NestedResourceInstance(APIView):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^$', Root.as_view()),
+ url(r'^resource/$', ResourceRoot.as_view()),
+ url(r'^resource/(?P[0-9]+)$', ResourceInstance.as_view()),
+ url(r'^resource/(?P[0-9]+)/$', NestedResourceRoot.as_view()),
+ url(r'^resource/(?P[0-9]+)/(?P[A-Za-z]+)$', NestedResourceInstance.as_view()),
+)
+
+
+class BreadcrumbTests(TestCase):
+ """Tests the breadcrumb functionality used by the HTML renderer."""
+
+ urls = 'tests.test_breadcrumbs'
+
+ def test_root_breadcrumbs(self):
+ url = '/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
+
+ def test_resource_root_breadcrumbs(self):
+ url = '/resource/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/')])
+
+ def test_resource_instance_breadcrumbs(self):
+ url = '/resource/123'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123')])
+
+ def test_nested_resource_breadcrumbs(self):
+ url = '/resource/123/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123'),
+ ('Nested Resource Root', '/resource/123/')])
+
+ def test_nested_resource_instance_breadcrumbs(self):
+ url = '/resource/123/abc'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123'),
+ ('Nested Resource Root', '/resource/123/'),
+ ('Nested Resource Instance', '/resource/123/abc')])
+
+ def test_broken_url_breadcrumbs_handled_gracefully(self):
+ url = '/foobar'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
new file mode 100644
index 00000000..195f0ba3
--- /dev/null
+++ b/tests/test_decorators.py
@@ -0,0 +1,157 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.authentication import BasicAuthentication
+from rest_framework.parsers import JSONParser
+from rest_framework.permissions import IsAuthenticated
+from rest_framework.response import Response
+from rest_framework.renderers import JSONRenderer
+from rest_framework.test import APIRequestFactory
+from rest_framework.throttling import UserRateThrottle
+from rest_framework.views import APIView
+from rest_framework.decorators import (
+ api_view,
+ renderer_classes,
+ parser_classes,
+ authentication_classes,
+ throttle_classes,
+ permission_classes,
+)
+
+
+class DecoratorTestCase(TestCase):
+
+ def setUp(self):
+ self.factory = APIRequestFactory()
+
+ def _finalize_response(self, request, response, *args, **kwargs):
+ response.request = request
+ return APIView.finalize_response(self, request, response, *args, **kwargs)
+
+ def test_api_view_incorrect(self):
+ """
+ If @api_view is not applied correct, we should raise an assertion.
+ """
+
+ @api_view
+ def view(request):
+ return Response()
+
+ request = self.factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
+
+ with self.assertRaises(AssertionError):
+ @api_view('GET')
+ def view(request):
+ return Response()
+
+ def test_calling_method(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_calling_put_method(self):
+
+ @api_view(['GET', 'PUT'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.put('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_calling_patch_method(self):
+
+ @api_view(['GET', 'PATCH'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.patch('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_renderer_classes(self):
+
+ @api_view(['GET'])
+ @renderer_classes([JSONRenderer])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer))
+
+ def test_parser_classes(self):
+
+ @api_view(['GET'])
+ @parser_classes([JSONParser])
+ def view(request):
+ self.assertEqual(len(request.parsers), 1)
+ self.assertTrue(isinstance(request.parsers[0],
+ JSONParser))
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_authentication_classes(self):
+
+ @api_view(['GET'])
+ @authentication_classes([BasicAuthentication])
+ def view(request):
+ self.assertEqual(len(request.authenticators), 1)
+ self.assertTrue(isinstance(request.authenticators[0],
+ BasicAuthentication))
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_permission_classes(self):
+
+ @api_view(['GET'])
+ @permission_classes([IsAuthenticated])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_throttle_classes(self):
+ class OncePerDayUserThrottle(UserRateThrottle):
+ rate = '1/day'
+
+ @api_view(['GET'])
+ @throttle_classes([OncePerDayUserThrottle])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/tests/test_description.py b/tests/test_description.py
new file mode 100644
index 00000000..1e481f06
--- /dev/null
+++ b/tests/test_description.py
@@ -0,0 +1,108 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.compat import apply_markdown, smart_text
+from rest_framework.views import APIView
+from .description import ViewWithNonASCIICharactersInDocstring
+from .description import UTF8_TEST_DOCSTRING
+
+# We check that docstrings get nicely un-indented.
+DESCRIPTION = """an example docstring
+====================
+
+* list
+* list
+
+another header
+--------------
+
+ code block
+
+indented
+
+# hash style header #"""
+
+# If markdown is installed we also test it's working
+# (and that our wrapped forces '=' to h2 and '-' to h3)
+
+# We support markdown < 2.1 and markdown >= 2.1
+MARKED_DOWN_lt_21 = """
an example docstring
+
+
list
+
list
+
+
another header
+
code block
+
+
indented
+
hash style header
"""
+
+MARKED_DOWN_gte_21 = """
an example docstring
+
+
list
+
list
+
+
another header
+
code block
+
+
indented
+
hash style header
"""
+
+
+class TestViewNamesAndDescriptions(TestCase):
+ def test_view_name_uses_class_name(self):
+ """
+ Ensure view names are based on the class name.
+ """
+ class MockView(APIView):
+ pass
+ self.assertEqual(MockView().get_view_name(), 'Mock')
+
+ def test_view_description_uses_docstring(self):
+ """Ensure view descriptions are based on the docstring."""
+ class MockView(APIView):
+ """an example docstring
+ ====================
+
+ * list
+ * list
+
+ another header
+ --------------
+
+ code block
+
+ indented
+
+ # hash style header #"""
+
+ self.assertEqual(MockView().get_view_description(), DESCRIPTION)
+
+ def test_view_description_supports_unicode(self):
+ """
+ Unicode in docstrings should be respected.
+ """
+
+ self.assertEqual(
+ ViewWithNonASCIICharactersInDocstring().get_view_description(),
+ smart_text(UTF8_TEST_DOCSTRING)
+ )
+
+ def test_view_description_can_be_empty(self):
+ """
+ Ensure that if a view has no docstring,
+ then it's description is the empty string.
+ """
+ class MockView(APIView):
+ pass
+ self.assertEqual(MockView().get_view_description(), '')
+
+ def test_markdown(self):
+ """
+ Ensure markdown to HTML works as expected.
+ """
+ if apply_markdown:
+ gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
+ lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
+ self.assertTrue(gte_21_match or lt_21_match)
diff --git a/tests/test_fields.py b/tests/test_fields.py
new file mode 100644
index 00000000..e65a2fb3
--- /dev/null
+++ b/tests/test_fields.py
@@ -0,0 +1,984 @@
+"""
+General serializer field tests.
+"""
+from __future__ import unicode_literals
+
+import datetime
+from decimal import Decimal
+from uuid import uuid4
+from django.core import validators
+from django.db import models
+from django.test import TestCase
+from django.utils.datastructures import SortedDict
+from rest_framework import serializers
+from tests.models import RESTFrameworkModel
+
+
+class TimestampedModel(models.Model):
+ added = models.DateTimeField(auto_now_add=True)
+ updated = models.DateTimeField(auto_now=True)
+
+
+class CharPrimaryKeyModel(models.Model):
+ id = models.CharField(max_length=20, primary_key=True)
+
+
+class TimestampedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimestampedModel
+
+
+class CharPrimaryKeyModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = CharPrimaryKeyModel
+
+
+class TimeFieldModel(models.Model):
+ clock = models.TimeField()
+
+
+class TimeFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TimeFieldModel
+
+
+SAMPLE_CHOICES = [
+ ('red', 'Red'),
+ ('green', 'Green'),
+ ('blue', 'Blue'),
+]
+
+
+class ChoiceFieldModel(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255)
+
+
+class ChoiceFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModel
+
+
+class ChoiceFieldModelWithNull(models.Model):
+ choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255)
+
+
+class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChoiceFieldModelWithNull
+
+
+class BasicFieldTests(TestCase):
+ def test_auto_now_fields_read_only(self):
+ """
+ auto_now and auto_now_add fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEqual(serializer.fields['added'].read_only, True)
+
+ def test_auto_pk_fields_read_only(self):
+ """
+ AutoField fields should be read_only by default.
+ """
+ serializer = TimestampedModelSerializer()
+ self.assertEqual(serializer.fields['id'].read_only, True)
+
+ def test_non_auto_pk_fields_not_read_only(self):
+ """
+ PK fields other than AutoField fields should not be read_only by default.
+ """
+ serializer = CharPrimaryKeyModelSerializer()
+ self.assertEqual(serializer.fields['id'].read_only, False)
+
+ def test_dict_field_ordering(self):
+ """
+ Field should preserve dictionary ordering, if it exists.
+ See: https://github.com/tomchristie/django-rest-framework/issues/832
+ """
+ ret = SortedDict()
+ ret['c'] = 1
+ ret['b'] = 1
+ ret['a'] = 1
+ ret['z'] = 1
+ field = serializers.Field()
+ keys = list(field.to_native(ret).keys())
+ self.assertEqual(keys, ['c', 'b', 'a', 'z'])
+
+
+class DateFieldTest(TestCase):
+ """
+ Tests for the DateFieldTest from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native('1984-07-31')
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_from_native_datetime_date(self):
+ """
+ Make sure from_native() accepts a datetime.date instance.
+ """
+ f = serializers.DateField()
+ result_1 = f.from_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(result_1, datetime.date(1984, 7, 31))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+ result = f.from_native('1984 -- 31')
+
+ self.assertEqual(datetime.date(1984, 1, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateField(input_formats=['%Y -- %d'])
+
+ try:
+ f.from_native('1984-07-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY -- DD"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_date(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid date.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984-13-31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateField()
+
+ try:
+ f.from_native('1984 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns datetime as default.
+ """
+ f = serializers.DateField()
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with 'iso-8601' returns iso formated date.
+ """
+ f = serializers.DateField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984-07-31', result_1)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateField(format="%Y - %m.%d")
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
+ self.assertEqual('1984 - 07.31', result_1)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class DateTimeFieldTest(TestCase):
+ """
+ Tests for the DateTimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native('1984-07-31 04:31')
+ result_2 = f.from_native('1984-07-31 04:31:59')
+ result_3 = f.from_native('1984-07-31 04:31:59.000200')
+
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_datetime(self):
+ """
+ Make sure from_native() accepts a datetime.datetime instance.
+ """
+ f = serializers.DateTimeField()
+ result_1 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_2 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_3 = f.from_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.datetime(1984, 7, 31, 4, 31))
+ self.assertEqual(result_2, datetime.datetime(1984, 7, 31, 4, 31, 59))
+ self.assertEqual(result_3, datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+ result = f.from_native('1984 -- 04:59')
+
+ self.assertEqual(datetime.datetime(1984, 1, 1, 4, 59), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.DateTimeField(input_formats=['%Y -- %H:%M'])
+
+ try:
+ f.from_native('1984-07-31 04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: YYYY -- hh:mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_datetime(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid datetime.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.DateTimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Datetime has wrong format. Use one of these formats instead: "
+ "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns isoformat as default.
+ """
+ f = serializers.DateTimeField()
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format=iso-8601 returns iso formatted datetime.
+ """
+ f = serializers.DateTimeField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984-07-31T00:00:00', result_1)
+ self.assertEqual('1984-07-31T04:31:00', result_2)
+ self.assertEqual('1984-07-31T04:31:59', result_3)
+ self.assertEqual('1984-07-31T04:31:59.000200', result_4)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.DateTimeField(format="%Y - %H:%M")
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
+ self.assertEqual('1984 - 00:00', result_1)
+ self.assertEqual('1984 - 04:31', result_2)
+ self.assertEqual('1984 - 04:31', result_3)
+ self.assertEqual('1984 - 04:31', result_4)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DateTimeField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+
+class TimeFieldTest(TestCase):
+ """
+ Tests for the TimeField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts default iso input formats.
+ """
+ f = serializers.TimeField()
+ result_1 = f.from_native('04:31')
+ result_2 = f.from_native('04:31:59')
+ result_3 = f.from_native('04:31:59.000200')
+
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_from_native_datetime_time(self):
+ """
+ Make sure from_native() accepts a datetime.time instance.
+ """
+ f = serializers.TimeField()
+ result_1 = f.from_native(datetime.time(4, 31))
+ result_2 = f.from_native(datetime.time(4, 31, 59))
+ result_3 = f.from_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(result_1, datetime.time(4, 31))
+ self.assertEqual(result_2, datetime.time(4, 31, 59))
+ self.assertEqual(result_3, datetime.time(4, 31, 59, 200))
+
+ def test_from_native_custom_format(self):
+ """
+ Make sure from_native() accepts custom input formats.
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+ result = f.from_native('04 -- 31')
+
+ self.assertEqual(datetime.time(4, 31), result)
+
+ def test_from_native_invalid_default_on_custom_format(self):
+ """
+ Make sure from_native() don't accept default formats if custom format is preset
+ """
+ f = serializers.TimeField(input_formats=['%H -- %M'])
+
+ try:
+ f.from_native('04:31:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: hh -- mm"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.TimeField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_from_native_invalid_time(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid time.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04:61:59')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_invalid_format(self):
+ """
+ Make sure from_native() raises a ValidationError on passing an invalid format.
+ """
+ f = serializers.TimeField()
+
+ try:
+ f.from_native('04 -- 31')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Time has wrong format. Use one of these formats instead: "
+ "hh:mm[:ss[.uuuuuu]]"])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns time object as default.
+ """
+ f = serializers.TimeField()
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format='iso-8601' returns iso formatted time.
+ """
+ f = serializers.TimeField(format='iso-8601')
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04:31:00', result_1)
+ self.assertEqual('04:31:59', result_2)
+ self.assertEqual('04:31:59.000200', result_3)
+
+ def test_to_native_custom_format(self):
+ """
+ Make sure to_native() returns correct custom format.
+ """
+ f = serializers.TimeField(format="%H - %S [%f]")
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
+ self.assertEqual('04 - 00 [000000]', result_1)
+ self.assertEqual('04 - 59 [000000]', result_2)
+ self.assertEqual('04 - 59 [000200]', result_3)
+
+
+class DecimalFieldTest(TestCase):
+ """
+ Tests for the DecimalField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts string values
+ """
+ f = serializers.DecimalField()
+ result_1 = f.from_native('9000')
+ result_2 = f.from_native('1.00000001')
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_from_native_invalid_string(self):
+ """
+ Make sure from_native() raises ValidationError on passing invalid string
+ """
+ f = serializers.DecimalField()
+
+ try:
+ f.from_native('123.45.6')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Enter a number."])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_integer(self):
+ """
+ Make sure from_native() accepts integer values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(9000)
+
+ self.assertEqual(Decimal('9000'), result)
+
+ def test_from_native_float(self):
+ """
+ Make sure from_native() accepts float values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(1.00000001)
+
+ self.assertEqual(Decimal('1.00000001'), result)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns Decimal as string.
+ """
+ f = serializers.DecimalField()
+
+ result_1 = f.to_native(Decimal('9000'))
+ result_2 = f.to_native(Decimal('1.00000001'))
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+ def test_valid_serialization(self):
+ """
+ Make sure the serializer works correctly
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(max_value=9010,
+ min_value=9000,
+ max_digits=6,
+ decimal_places=2)
+
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
+
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
+
+ def test_raise_max_value(self):
+ """
+ Make sure max_value violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(max_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '123'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
+
+ def test_raise_min_value(self):
+ """
+ Make sure min_value violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(min_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '99'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
+
+ def test_raise_max_digits(self):
+ """
+ Make sure max_digits violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(max_digits=5)
+
+ s = DecimalSerializer(data={'decimal_field': '123.456'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
+
+ def test_raise_max_decimal_places(self):
+ """
+ Make sure max_decimal_places violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '123.4567'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
+
+ def test_raise_max_whole_digits(self):
+ """
+ Make sure max_whole_digits violations raises ValidationError
+ """
+ class DecimalSerializer(serializers.Serializer):
+ decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '12345.6'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
+
+
+class ChoiceFieldTests(TestCase):
+ """
+ Tests for the ChoiceField options generator
+ """
+ def test_choices_required(self):
+ """
+ Make sure proper choices are rendered if field is required
+ """
+ f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, SAMPLE_CHOICES)
+
+ def test_choices_not_required(self):
+ """
+ Make sure proper choices (plus blank) are rendered if the field isn't required
+ """
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
+
+ def test_invalid_choice_model(self):
+ s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']})
+ self.assertEqual(s.data['choice'], '')
+
+ def test_empty_choice_model(self):
+ """
+ Test that the 'empty' value is correctly passed and used depending on
+ the 'null' property on the model field.
+ """
+ s = ChoiceFieldModelSerializer(data={'choice': ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], '')
+
+ s = ChoiceFieldModelWithNullSerializer(data={'choice': ''})
+ self.assertTrue(s.is_valid())
+ self.assertEqual(s.data['choice'], None)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns an empty string on empty param by default.
+ """
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.from_native(''), '')
+ self.assertEqual(f.from_native(None), '')
+
+ def test_from_native_empty_override(self):
+ """
+ Make sure you can override from_native() behavior regarding empty values.
+ """
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None)
+ self.assertEqual(f.from_native(''), None)
+ self.assertEqual(f.from_native(None), None)
+
+ def test_metadata_choices(self):
+ """
+ Make sure proper choices are included in the field's metadata.
+ """
+ choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES]
+ f = serializers.ChoiceField(choices=SAMPLE_CHOICES)
+ self.assertEqual(f.metadata()['choices'], choices)
+
+ def test_metadata_choices_not_required(self):
+ """
+ Make sure proper choices are included in the field's metadata.
+ """
+ choices = [{'value': v, 'display_name': n}
+ for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES]
+ f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
+ self.assertEqual(f.metadata()['choices'], choices)
+
+
+class EmailFieldTests(TestCase):
+ """
+ Tests for EmailField attribute values
+ """
+
+ class EmailFieldModel(RESTFrameworkModel):
+ email_field = models.EmailField(blank=True)
+
+ class EmailFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ email_field = models.EmailField(max_length=150, blank=True)
+
+ def test_default_model_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.EmailFieldModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 75)
+
+ def test_given_model_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.EmailFieldWithGivenMaxLengthModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 150)
+
+ def test_given_serializer_value(self):
+ class EmailFieldSerializer(serializers.ModelSerializer):
+ email_field = serializers.EmailField(source='email_field', max_length=20, required=False)
+
+ class Meta:
+ model = self.EmailFieldModel
+
+ serializer = EmailFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['email_field'], 'max_length'), 20)
+
+
+class SlugFieldTests(TestCase):
+ """
+ Tests for SlugField attribute values
+ """
+
+ class SlugFieldModel(RESTFrameworkModel):
+ slug_field = models.SlugField(blank=True)
+
+ class SlugFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ slug_field = models.SlugField(max_length=84, blank=True)
+
+ def test_default_model_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.SlugFieldModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 50)
+
+ def test_given_model_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.SlugFieldWithGivenMaxLengthModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'], 'max_length'), 84)
+
+ def test_given_serializer_value(self):
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ slug_field = serializers.SlugField(source='slug_field',
+ max_length=20, required=False)
+
+ class Meta:
+ model = self.SlugFieldModel
+
+ serializer = SlugFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['slug_field'],
+ 'max_length'), 20)
+
+ def test_invalid_slug(self):
+ """
+ Make sure an invalid slug raises ValidationError
+ """
+ class SlugFieldSerializer(serializers.ModelSerializer):
+ slug_field = serializers.SlugField(source='slug_field', max_length=20, required=True)
+
+ class Meta:
+ model = self.SlugFieldModel
+
+ s = SlugFieldSerializer(data={'slug_field': 'a b'})
+
+ self.assertEqual(s.is_valid(), False)
+ self.assertEqual(s.errors, {'slug_field': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."]})
+
+
+class URLFieldTests(TestCase):
+ """
+ Tests for URLField attribute values.
+
+ (Includes test for #1210, checking that validators can be overridden.)
+ """
+
+ class URLFieldModel(RESTFrameworkModel):
+ url_field = models.URLField(blank=True)
+
+ class URLFieldWithGivenMaxLengthModel(RESTFrameworkModel):
+ url_field = models.URLField(max_length=128, blank=True)
+
+ def test_default_model_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.URLFieldModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 200)
+
+ def test_given_model_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = self.URLFieldWithGivenMaxLengthModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 128)
+
+ def test_given_serializer_value(self):
+ class URLFieldSerializer(serializers.ModelSerializer):
+ url_field = serializers.URLField(source='url_field',
+ max_length=20, required=False)
+
+ class Meta:
+ model = self.URLFieldWithGivenMaxLengthModel
+
+ serializer = URLFieldSerializer(data={})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(getattr(serializer.fields['url_field'],
+ 'max_length'), 20)
+
+ def test_validators_can_be_overridden(self):
+ url_field = serializers.URLField(validators=[])
+ validators = url_field.validators
+ self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators')
+
+
+class FieldMetadata(TestCase):
+ def setUp(self):
+ self.required_field = serializers.Field()
+ self.required_field.label = uuid4().hex
+ self.required_field.required = True
+
+ self.optional_field = serializers.Field()
+ self.optional_field.label = uuid4().hex
+ self.optional_field.required = False
+
+ def test_required(self):
+ self.assertEqual(self.required_field.metadata()['required'], True)
+
+ def test_optional(self):
+ self.assertEqual(self.optional_field.metadata()['required'], False)
+
+ def test_label(self):
+ for field in (self.required_field, self.optional_field):
+ self.assertEqual(field.metadata()['label'], field.label)
+
+
+class FieldCallableDefault(TestCase):
+ def setUp(self):
+ self.simple_callable = lambda: 'foo bar'
+
+ def test_default_can_be_simple_callable(self):
+ """
+ Ensure that the 'default' argument can also be a simple callable.
+ """
+ field = serializers.WritableField(default=self.simple_callable)
+ into = {}
+ field.field_from_native({}, {}, 'field', into)
+ self.assertEqual(into, {'field': 'foo bar'})
+
+
+class CustomIntegerField(TestCase):
+ """
+ Test that custom fields apply min_value and max_value constraints
+ """
+ def test_custom_fields_can_be_validated_for_value(self):
+
+ class MoneyField(models.PositiveIntegerField):
+ pass
+
+ class EntryModel(models.Model):
+ bank = MoneyField(validators=[validators.MaxValueValidator(100)])
+
+ class EntrySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = EntryModel
+
+ entry = EntryModel(bank=1)
+
+ serializer = EntrySerializer(entry, data={"bank": 11})
+ self.assertTrue(serializer.is_valid())
+
+ serializer = EntrySerializer(entry, data={"bank": -1})
+ self.assertFalse(serializer.is_valid())
+
+ serializer = EntrySerializer(entry, data={"bank": 101})
+ self.assertFalse(serializer.is_valid())
+
+
+class BooleanField(TestCase):
+ """
+ Tests for BooleanField
+ """
+ def test_boolean_required(self):
+ class BooleanRequiredSerializer(serializers.Serializer):
+ bool_field = serializers.BooleanField(required=True)
+
+ self.assertFalse(BooleanRequiredSerializer(data={}).is_valid())
diff --git a/tests/test_files.py b/tests/test_files.py
new file mode 100644
index 00000000..78f4cf42
--- /dev/null
+++ b/tests/test_files.py
@@ -0,0 +1,95 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.compat import BytesIO
+from rest_framework.compat import six
+import datetime
+
+
+class UploadedFile(object):
+ def __init__(self, file=None, created=None):
+ self.file = file
+ self.created = created or datetime.datetime.now()
+
+
+class UploadedFileSerializer(serializers.Serializer):
+ file = serializers.FileField(required=False)
+ created = serializers.DateTimeField()
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.file = attrs['file']
+ instance.created = attrs['created']
+ return instance
+ return UploadedFile(**attrs)
+
+
+class FileSerializerTests(TestCase):
+ def test_create(self):
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+ serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})
+ uploaded_file = UploadedFile(file=file, created=now)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertEqual(serializer.object.file, uploaded_file.file)
+ self.assertFalse(serializer.object is uploaded_file)
+
+ def test_creation_failure(self):
+ """
+ Passing files=None should result in an ValidationError
+
+ Regression test for:
+ https://github.com/tomchristie/django-rest-framework/issues/542
+ """
+ now = datetime.datetime.now()
+
+ serializer = UploadedFileSerializer(data={'created': now})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, now)
+ self.assertIsNone(serializer.object.file)
+
+ def test_remove_with_empty_string(self):
+ """
+ Passing empty string as data should cause file to be removed
+
+ Test for:
+ https://github.com/tomchristie/django-rest-framework/issues/937
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.object.created, uploaded_file.created)
+ self.assertIsNone(serializer.object.file)
+
+ def test_validation_error_with_non_file(self):
+ """
+ Passing non-files should raise a validation error.
+ """
+ now = datetime.datetime.now()
+ errmsg = 'No file was submitted. Check the encoding type on the form.'
+
+ serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'file': [errmsg]})
+
+ def test_validation_with_no_data(self):
+ """
+ Validation should still function when no data dictionary is provided.
+ """
+ now = datetime.datetime.now()
+ file = BytesIO(six.b('stuff'))
+ file.name = 'stuff.txt'
+ file.size = len(file.getvalue())
+ uploaded_file = UploadedFile(file=file, created=now)
+
+ serializer = UploadedFileSerializer(files={'file': file})
+ self.assertFalse(serializer.is_valid())
diff --git a/tests/test_filters.py b/tests/test_filters.py
new file mode 100644
index 00000000..d9d8042e
--- /dev/null
+++ b/tests/test_filters.py
@@ -0,0 +1,615 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.core.urlresolvers import reverse
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import generics, serializers, status, filters
+from rest_framework.compat import django_filters, patterns, url
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel
+
+factory = APIRequestFactory()
+
+
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # Regression test for #814
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+ class FilterFieldsQuerysetView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
+
+ urlpatterns = patterns('',
+ url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
+ )
+
+
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ self._serialize_object(obj)
+ for obj in self.objects.all()
+ ]
+
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_filter_with_queryset(self):
+ """
+ Regression test for #814.
+ """
+ view = FilterFieldsQuerysetView.as_view()
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'tests.test_filters'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=zzz')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('?search=b')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+
+class OrdringFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterRelatedModel(models.Model):
+ related_object = models.ForeignKey(OrdringFilterModel,
+ related_name="relateds")
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrdringFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-text')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=foobar')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ oredering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_ordering_by_aggregate_field(self):
+ # create some related models to aggregate order by
+ num_objs = [2, 5, 3]
+ for obj, num_relateds in zip(OrdringFilterModel.objects.all(),
+ num_objs):
+ for _ in range(num_relateds):
+ new_related = OrderingFilterRelatedModel(
+ related_object=obj
+ )
+ new_related.save()
+
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = '__all__'
+ queryset = OrdringFilterModel.objects.all().annotate(
+ models.Count("relateds"))
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=relateds__count')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ ]
+ )
+
+
+class SensitiveOrderingFilterModel(models.Model):
+ username = models.CharField(max_length=20)
+ password = models.CharField(max_length=100)
+
+
+# Three different styles of serializer.
+# All should allow ordering by username, but not by password.
+class SensitiveDataSerializer1(serializers.ModelSerializer):
+ username = serializers.CharField()
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'username')
+
+
+class SensitiveDataSerializer2(serializers.ModelSerializer):
+ username = serializers.CharField()
+ password = serializers.CharField(write_only=True)
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'username', 'password')
+
+
+class SensitiveDataSerializer3(serializers.ModelSerializer):
+ user = serializers.CharField(source='username')
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'user')
+
+
+class SensitiveOrderingFilterTests(TestCase):
+ def setUp(self):
+ for idx in range(3):
+ username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
+ password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
+ SensitiveOrderingFilterModel(username=username, password=password).save()
+
+ def test_order_by_serializer_fields(self):
+ for serializer_cls in [
+ SensitiveDataSerializer1,
+ SensitiveDataSerializer2,
+ SensitiveDataSerializer3
+ ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=-username')
+ response = view(request)
+
+ if serializer_cls == SensitiveDataSerializer3:
+ username_field = 'user'
+ else:
+ username_field = 'username'
+
+ # Note: Inverse username ordering correctly applied.
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, username_field: 'userC'},
+ {'id': 2, username_field: 'userB'},
+ {'id': 1, username_field: 'userA'},
+ ]
+ )
+
+ def test_cannot_order_by_non_serializer_fields(self):
+ for serializer_cls in [
+ SensitiveDataSerializer1,
+ SensitiveDataSerializer2,
+ SensitiveDataSerializer3
+ ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
+
+ view = OrderingListView.as_view()
+ request = factory.get('?ordering=password')
+ response = view(request)
+
+ if serializer_cls == SensitiveDataSerializer3:
+ username_field = 'user'
+ else:
+ username_field = 'username'
+
+ # Note: The passwords are not in order. Default ordering is used.
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, username_field: 'userA'}, # PassB
+ {'id': 2, username_field: 'userB'}, # PassC
+ {'id': 3, username_field: 'userC'}, # PassA
+ ]
+ )
diff --git a/tests/test_genericrelations.py b/tests/test_genericrelations.py
new file mode 100644
index 00000000..2d341344
--- /dev/null
+++ b/tests/test_genericrelations.py
@@ -0,0 +1,129 @@
+from __future__ import unicode_literals
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class Tag(models.Model):
+ """
+ Tags have a descriptive slug, and are attached to an arbitrary object.
+ """
+ tag = models.SlugField()
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ tagged_item = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag
+
+
+class Bookmark(models.Model):
+ """
+ A URL bookmark that may have multiple tags attached.
+ """
+ url = models.URLField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Bookmark: %s' % self.url
+
+
+class Note(models.Model):
+ """
+ A textual note that may have multiple tags attached.
+ """
+ text = models.TextField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Note: %s' % self.text
+
+
+class TestGenericRelations(TestCase):
+ def setUp(self):
+ self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
+ Tag.objects.create(tagged_item=self.bookmark, tag='django')
+ Tag.objects.create(tagged_item=self.bookmark, tag='python')
+ self.note = Note.objects.create(text='Remember the milk')
+ Tag.objects.create(tagged_item=self.note, tag='reminder')
+
+ def test_generic_relation(self):
+ """
+ Test a relationship that spans a GenericRelation field.
+ IE. A reverse generic relationship.
+ """
+
+ class BookmarkSerializer(serializers.ModelSerializer):
+ tags = serializers.RelatedField(many=True)
+
+ class Meta:
+ model = Bookmark
+ exclude = ('id',)
+
+ serializer = BookmarkSerializer(self.bookmark)
+ expected = {
+ 'tags': ['django', 'python'],
+ 'url': 'https://www.djangoproject.com/'
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_generic_nested_relation(self):
+ """
+ Test saving a GenericRelation field via a nested serializer.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Tag
+ exclude = ('content_type', 'object_id')
+
+ class BookmarkSerializer(serializers.ModelSerializer):
+ tags = TagSerializer()
+
+ class Meta:
+ model = Bookmark
+ exclude = ('id',)
+
+ data = {
+ 'url': 'https://docs.djangoproject.com/',
+ 'tags': [
+ {'tag': 'contenttypes'},
+ {'tag': 'genericrelations'},
+ ]
+ }
+ serializer = BookmarkSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.object.tags.count(), 2)
+
+ def test_generic_fk(self):
+ """
+ Test a relationship that spans a GenericForeignKey field.
+ IE. A forward generic relationship.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ tagged_item = serializers.RelatedField()
+
+ class Meta:
+ model = Tag
+ exclude = ('id', 'content_type', 'object_id')
+
+ serializer = TagSerializer(Tag.objects.all(), many=True)
+ expected = [
+ {
+ 'tag': 'django',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'python',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'reminder',
+ 'tagged_item': 'Note: Remember the milk'
+ }
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_generics.py b/tests/test_generics.py
new file mode 100644
index 00000000..4389994a
--- /dev/null
+++ b/tests/test_generics.py
@@ -0,0 +1,609 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.shortcuts import get_object_or_404
+from django.test import TestCase
+from rest_framework import generics, renderers, serializers, status
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel, Comment, SlugBasedModel
+from rest_framework.compat import six
+
+factory = APIRequestFactory()
+
+
+class RootView(generics.ListCreateAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+
+ def get_queryset(self):
+ queryset = super(InstanceView, self).get_queryset()
+ return queryset.exclude(text='filtered out')
+
+
+class SlugSerializer(serializers.ModelSerializer):
+ slug = serializers.Field() # read only
+
+ class Meta:
+ model = SlugBasedModel
+ exclude = ('id',)
+
+
+class SlugBasedInstanceView(InstanceView):
+ """
+ A model with a slug-field.
+ """
+ model = SlugBasedModel
+ serializer_class = SlugSerializer
+ lookup_field = 'slug'
+
+
+class TestRootView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_root_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_post_root_view(self):
+ """
+ POST requests to ListCreateAPIView should create a new object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEqual(created.text, 'foobar')
+
+ def test_put_root_view(self):
+ """
+ PUT requests to ListCreateAPIView should not be allowed
+ """
+ data = {'text': 'foobar'}
+ request = factory.put('/', data, format='json')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."})
+
+ def test_delete_root_view(self):
+ """
+ DELETE requests to ListCreateAPIView should not be allowed
+ """
+ request = factory.delete('/')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."})
+
+ def test_options_root_view(self):
+ """
+ OPTIONS requests to ListCreateAPIView should return metadata
+ """
+ request = factory.options('/')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Root',
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'POST': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ "label": "Text comes here",
+ "help_text": "Text description."
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
+ }
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
+
+ def test_post_cannot_set_id(self):
+ """
+ POST requests to create a new object should not be able to set the id.
+ """
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEqual(created.text, 'foobar')
+
+
+class TestInstanceView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz', 'filtered out']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects.exclude(text='filtered out')
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = InstanceView.as_view()
+ self.slug_based_view = SlugBasedInstanceView.as_view()
+
+ def test_get_instance_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+ def test_post_instance_view(self):
+ """
+ POST requests to RetrieveUpdateDestroyAPIView should not be allowed
+ """
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."})
+
+ def test_put_instance_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk='1').render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_patch_instance_view(self):
+ """
+ PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/1', data, format='json')
+
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_delete_instance_view(self):
+ """
+ DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
+ """
+ request = factory.delete('/1')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ self.assertEqual(response.content, six.b(''))
+ ids = [obj.id for obj in self.objects.all()]
+ self.assertEqual(ids, [2, 3])
+
+ def test_options_instance_view(self):
+ """
+ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
+ """
+ request = factory.options('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Instance',
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'PUT': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ 'label': 'Text comes here',
+ 'help_text': 'Text description.'
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
+ }
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
+
+ def test_options_before_instance_create(self):
+ """
+ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
+ before the instance has been created
+ """
+ request = factory.options('/999')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Instance',
+ 'description': 'Example description for OPTIONS.',
+ 'actions': {
+ 'PUT': {
+ 'text': {
+ 'max_length': 100,
+ 'read_only': False,
+ 'required': True,
+ 'type': 'string',
+ 'label': 'Text comes here',
+ 'help_text': 'Text description.'
+ },
+ 'id': {
+ 'read_only': True,
+ 'required': False,
+ 'type': 'integer',
+ 'label': 'ID',
+ },
+ }
+ }
+ }
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, expected)
+
+ def test_get_instance_view_incorrect_arg(self):
+ """
+ GET requests with an incorrect pk type, should raise 404, not 500.
+ Regression test for #890.
+ """
+ request = factory.get('/a')
+ with self.assertNumQueries(0):
+ response = self.view(request, pk='a').render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ def test_put_cannot_set_id(self):
+ """
+ PUT requests to create a new object should not be able to set the id.
+ """
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_put_to_deleted_instance(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ if it does not currently exist.
+ """
+ self.objects.get(id=1).delete()
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_put_to_filtered_out_instance(self):
+ """
+ PUT requests to an URL of instance which is filtered out should not be
+ able to create new objects.
+ """
+ data = {'text': 'foo'}
+ filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
+ request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
+ response = self.view(request, pk=filtered_out_pk).render()
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+ def test_put_as_create_on_id_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if it doesn't exist.
+ """
+ data = {'text': 'foobar'}
+ # pk fields can not be created on demand, only the database can set the pk for a new object
+ request = factory.put('/5', data, format='json')
+ with self.assertNumQueries(3):
+ response = self.view(request, pk=5).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ new_obj = self.objects.get(pk=5)
+ self.assertEqual(new_obj.text, 'foobar')
+
+ def test_put_as_create_on_slug_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
+ """
+ data = {'text': 'foobar'}
+ request = factory.put('/test_slug', data, format='json')
+ with self.assertNumQueries(2):
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ new_obj = SlugBasedModel.objects.get(slug='test_slug')
+ self.assertEqual(new_obj.text, 'foobar')
+
+ def test_patch_cannot_create_an_object(self):
+ """
+ PATCH requests should not be able to create objects.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/999', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertFalse(self.objects.filter(id=999).exists())
+
+
+class TestOverriddenGetObject(TestCase):
+ """
+ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
+ queryset/model mechanism but instead overrides get_object()
+ """
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example detail view for override of get_object().
+ """
+ model = BasicModel
+
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
+
+ self.view = OverriddenGetObjectView.as_view()
+
+ def test_overridden_get_object_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
+# Regression test for #285
+
+class CommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Comment
+ exclude = ('created',)
+
+
+class CommentView(generics.ListCreateAPIView):
+ serializer_class = CommentSerializer
+ model = Comment
+
+
+class TestCreateModelWithAutoNowAddField(TestCase):
+ def setUp(self):
+ self.objects = Comment.objects
+ self.view = CommentView.as_view()
+
+ def test_create_model_with_auto_now_add_field(self):
+ """
+ Regression test for #285
+
+ https://github.com/tomchristie/django-rest-framework/issues/285
+ """
+ data = {'email': 'foobar@example.com', 'content': 'foobar'}
+ request = factory.post('/', data, format='json')
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ created = self.objects.get(id=1)
+ self.assertEqual(created.content, 'foobar')
+
+
+# Test for particularly ugly regression with m2m in browsable API
+class ClassB(models.Model):
+ name = models.CharField(max_length=255)
+
+
+class ClassA(models.Model):
+ name = models.CharField(max_length=255)
+ childs = models.ManyToManyField(ClassB, blank=True, null=True)
+
+
+class ClassASerializer(serializers.ModelSerializer):
+ childs = serializers.PrimaryKeyRelatedField(many=True, source='childs')
+
+ class Meta:
+ model = ClassA
+
+
+class ExampleView(generics.ListCreateAPIView):
+ serializer_class = ClassASerializer
+ model = ClassA
+
+
+class TestM2MBrowseableAPI(TestCase):
+ def test_m2m_in_browseable_api(self):
+ """
+ Test for particularly ugly regression with m2m in browsable API
+ """
+ request = factory.get('/', HTTP_ACCEPT='text/html')
+ view = ExampleView().as_view()
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class InclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='foo')
+
+
+class ExclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='other')
+
+
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ model = TwoFieldModel
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ return DynamicSerializer
+ return super(DynamicSerializerView, self).get_serializer_class()
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances to filter on.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ def test_get_root_view_filters_by_name_with_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
+
+ def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return empty list when all models are filtered out.
+ """
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [])
+
+ def test_get_instance_view_filters_out_name_with_filter_backend(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
+ """
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.data, {'detail': 'Not found'})
+
+ def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
+ """
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
+
+ def test_dynamic_serializer_form_in_browsable_api(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ view = DynamicSerializerView.as_view()
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertContains(response, 'field_b')
+ self.assertNotContains(response, 'field_a')
diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py
new file mode 100644
index 00000000..c748fbdb
--- /dev/null
+++ b/tests/test_htmlrenderer.py
@@ -0,0 +1,118 @@
+from __future__ import unicode_literals
+from django.core.exceptions import PermissionDenied
+from django.http import Http404
+from django.test import TestCase
+from django.template import TemplateDoesNotExist, Template
+import django.template.loader
+from rest_framework import status
+from rest_framework.compat import patterns, url
+from rest_framework.decorators import api_view, renderer_classes
+from rest_framework.renderers import TemplateHTMLRenderer
+from rest_framework.response import Response
+from rest_framework.compat import six
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def example(request):
+ """
+ A view that can returns an HTML representation.
+ """
+ data = {'object': 'foobar'}
+ return Response(data, template_name='example.html')
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def permission_denied(request):
+ raise PermissionDenied()
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def not_found(request):
+ raise Http404()
+
+
+urlpatterns = patterns('',
+ url(r'^$', example),
+ url(r'^permission_denied$', permission_denied),
+ url(r'^not_found$', not_found),
+)
+
+
+class TemplateHTMLRendererTests(TestCase):
+ urls = 'tests.test_htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == 'example.html':
+ return Template("example: {{ object }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_simple_html_view(self):
+ response = self.client.get('/')
+ self.assertContains(response, "example: foobar")
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_not_found_html_view(self):
+ response = self.client.get('/not_found')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404 Not Found"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_permission_denied_html_view(self):
+ response = self.client.get('/permission_denied')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403 Forbidden"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+
+class TemplateHTMLRendererExceptionTests(TestCase):
+ urls = 'tests.test_htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == '404.html':
+ return Template("404: {{ detail }}")
+ if template_name == '403.html':
+ return Template("403: {{ detail }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_not_found_html_view_with_template(self):
+ response = self.client.get('/not_found')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404: Not found"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_permission_denied_html_view_with_template(self):
+ response = self.client.get('/permission_denied')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403: Permission denied"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/tests/test_hyperlinkedserializers.py b/tests/test_hyperlinkedserializers.py
new file mode 100644
index 00000000..eee179ca
--- /dev/null
+++ b/tests/test_hyperlinkedserializers.py
@@ -0,0 +1,379 @@
+from __future__ import unicode_literals
+import json
+from django.test import TestCase
+from rest_framework import generics, status, serializers
+from rest_framework.compat import patterns, url
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from tests.models import (
+ Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment,
+ Album, Photo, OptionalRelationModel
+)
+
+factory = APIRequestFactory()
+
+
+class BlogPostCommentSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail')
+ text = serializers.CharField()
+ blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
+
+ class Meta:
+ model = BlogPostComment
+ fields = ('text', 'blog_post_url', 'url')
+
+
+class PhotoSerializer(serializers.Serializer):
+ description = serializers.CharField()
+ album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title')
+
+ def restore_object(self, attrs, instance=None):
+ return Photo(**attrs)
+
+
+class AlbumSerializer(serializers.ModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='album-detail', lookup_field='title')
+
+ class Meta:
+ model = Album
+ fields = ('title', 'url')
+
+
+class BasicList(generics.ListCreateAPIView):
+ model = BasicModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class AnchorDetail(generics.RetrieveAPIView):
+ model = Anchor
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class ManyToManyList(generics.ListAPIView):
+ model = ManyToManyModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class ManyToManyDetail(generics.RetrieveAPIView):
+ model = ManyToManyModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class BlogPostCommentListCreate(generics.ListCreateAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostCommentDetail(generics.RetrieveAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostDetail(generics.RetrieveAPIView):
+ model = BlogPost
+
+
+class PhotoListCreate(generics.ListCreateAPIView):
+ model = Photo
+ model_serializer_class = PhotoSerializer
+
+
+class AlbumDetail(generics.RetrieveAPIView):
+ model = Album
+ serializer_class = AlbumSerializer
+ lookup_field = 'title'
+
+
+class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = OptionalRelationModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+urlpatterns = patterns('',
+ url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
+ url(r'^basic/(?P\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
+ url(r'^anchor/(?P\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
+ url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
+ url(r'^manytomany/(?P\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
+ url(r'^posts/(?P\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
+ url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
+ url(r'^comments/(?P\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'),
+ url(r'^albums/(?P\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
+ url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'),
+ url(r'^optionalrelation/(?P\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'),
+)
+
+
+class TestBasicHyperlinkedView(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.list_view = BasicList.as_view()
+ self.detail_view = BasicDetail.as_view()
+
+ def test_get_list_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/basic/')
+ response = self.list_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_get_detail_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/basic/1')
+ response = self.detail_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
+class TestManyToManyHyperlinkedView(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ anchors = []
+ for item in items:
+ anchor = Anchor(text=item)
+ anchor.save()
+ anchors.append(anchor)
+
+ manytomany = ManyToManyModel()
+ manytomany.save()
+ manytomany.rel.add(*anchors)
+
+ self.data = [{
+ 'url': 'http://testserver/manytomany/1/',
+ 'rel': [
+ 'http://testserver/anchor/1/',
+ 'http://testserver/anchor/2/',
+ 'http://testserver/anchor/3/',
+ ]
+ }]
+ self.list_view = ManyToManyList.as_view()
+ self.detail_view = ManyToManyDetail.as_view()
+
+ def test_get_list_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/manytomany/')
+ response = self.list_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_get_detail_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/manytomany/1/')
+ response = self.detail_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
+class TestHyperlinkedIdentityFieldLookup(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 Album instances.
+ """
+ titles = ['foo', 'bar', 'baz']
+ for title in titles:
+ album = Album(title=title)
+ album.save()
+ self.detail_view = AlbumDetail.as_view()
+ self.data = {
+ 'foo': {'title': 'foo', 'url': 'http://testserver/albums/foo/'},
+ 'bar': {'title': 'bar', 'url': 'http://testserver/albums/bar/'},
+ 'baz': {'title': 'baz', 'url': 'http://testserver/albums/baz/'}
+ }
+
+ def test_lookup_field(self):
+ """
+ GET requests to AlbumDetail view should return serialized Albums
+ with a url field keyed by `title`.
+ """
+ for album in Album.objects.all():
+ request = factory.get('/albums/{0}/'.format(album.title))
+ response = self.detail_view(request, title=album.title)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[album.title])
+
+
+class TestCreateWithForeignKeys(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create a blog post
+ """
+ self.post = BlogPost.objects.create(title="Test post")
+ self.create_view = BlogPostCommentListCreate.as_view()
+
+ def test_create_comment(self):
+
+ data = {
+ 'text': 'A test comment',
+ 'blog_post_url': 'http://testserver/posts/1/'
+ }
+
+ request = factory.post('/comments/', data=data)
+ response = self.create_view(request)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response['Location'], 'http://testserver/comments/1/')
+ self.assertEqual(self.post.blogpostcomment_set.count(), 1)
+ self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
+
+
+class TestCreateWithForeignKeysAndCustomSlug(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create an Album
+ """
+ self.post = Album.objects.create(title='test-album')
+ self.list_create_view = PhotoListCreate.as_view()
+
+ def test_create_photo(self):
+
+ data = {
+ 'description': 'A test photo',
+ 'album_url': 'http://testserver/albums/test-album/'
+ }
+
+ request = factory.post('/photos/', data=data)
+ response = self.list_create_view(request)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
+ self.assertEqual(self.post.photo_set.count(), 1)
+ self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
+
+
+class TestOptionalRelationHyperlinkedView(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 1 OptionalRelationModel instances.
+ """
+ OptionalRelationModel().save()
+ self.objects = OptionalRelationModel.objects
+ self.detail_view = OptionalRelationDetail.as_view()
+ self.data = {"url": "http://testserver/optionalrelation/1/", "other": None}
+
+ def test_get_detail_view(self):
+ """
+ GET requests to RetrieveAPIView with optional relations should return None
+ for non existing relations.
+ """
+ request = factory.get('/optionalrelationmodel-detail/1')
+ response = self.detail_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_put_detail_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView with optional relations
+ should accept None for non existing relations.
+ """
+ response = self.client.put('/optionalrelation/1/',
+ data=json.dumps(self.data),
+ content_type='application/json')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class TestOverriddenURLField(TestCase):
+ def setUp(self):
+ class OverriddenURLSerializer(serializers.HyperlinkedModelSerializer):
+ url = serializers.SerializerMethodField('get_url')
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', 'url')
+
+ def get_url(self, obj):
+ return 'foo bar'
+
+ self.Serializer = OverriddenURLSerializer
+ self.obj = BlogPost.objects.create(title='New blog post')
+
+ def test_overridden_url_field(self):
+ """
+ The 'url' field should respect overriding.
+ Regression test for #936.
+ """
+ serializer = self.Serializer(self.obj)
+ self.assertEqual(
+ serializer.data,
+ {'title': 'New blog post', 'url': 'foo bar'}
+ )
+
+
+class TestURLFieldNameBySettings(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ self.saved_url_field_name = api_settings.URL_FIELD_NAME
+ api_settings.URL_FIELD_NAME = 'global_url_field'
+
+ class Serializer(serializers.HyperlinkedModelSerializer):
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', api_settings.URL_FIELD_NAME)
+
+ self.Serializer = Serializer
+ self.obj = BlogPost.objects.create(title="New blog post")
+
+ def tearDown(self):
+ api_settings.URL_FIELD_NAME = self.saved_url_field_name
+
+ def test_overridden_url_field_name(self):
+ request = factory.get('/posts/')
+ serializer = self.Serializer(self.obj, context={'request': request})
+ self.assertIn(api_settings.URL_FIELD_NAME, serializer.data)
+
+
+class TestURLFieldNameByOptions(TestCase):
+ urls = 'tests.test_hyperlinkedserializers'
+
+ def setUp(self):
+ class Serializer(serializers.HyperlinkedModelSerializer):
+
+ class Meta:
+ model = BlogPost
+ fields = ('title', 'serializer_url_field')
+ url_field_name = 'serializer_url_field'
+
+ self.Serializer = Serializer
+ self.obj = BlogPost.objects.create(title="New blog post")
+
+ def test_overridden_url_field_name(self):
+ request = factory.get('/posts/')
+ serializer = self.Serializer(self.obj, context={'request': request})
+ self.assertIn(self.Serializer.Meta.url_field_name, serializer.data)
diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py
new file mode 100644
index 00000000..ce1bf3ea
--- /dev/null
+++ b/tests/test_multitable_inheritance.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import RESTFrameworkModel
+
+
+# Models
+class ParentModel(RESTFrameworkModel):
+ name1 = models.CharField(max_length=100)
+
+
+class ChildModel(ParentModel):
+ name2 = models.CharField(max_length=100)
+
+
+class AssociatedModel(RESTFrameworkModel):
+ ref = models.OneToOneField(ParentModel, primary_key=True)
+ name = models.CharField(max_length=100)
+
+
+# Serializers
+class DerivedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChildModel
+
+
+class AssociatedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AssociatedModel
+
+
+# Tests
+class IneritedModelSerializationTests(TestCase):
+
+ def test_multitable_inherited_model_fields_as_expected(self):
+ """
+ Assert that the parent pointer field is not included in the fields
+ serialized fields
+ """
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name1', 'name2', 'id']))
+
+ def test_onetoone_primary_key_model_fields_as_expected(self):
+ """
+ Assert that a model with a onetoone field that is the primary key is
+ not treated like a derived model
+ """
+ parent = ParentModel(name1='parent name')
+ associate = AssociatedModel(name='hello', ref=parent)
+ serializer = AssociatedModelSerializer(associate)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'ref']))
+
+ def test_data_is_valid_without_parent_ptr(self):
+ """
+ Assert that the pointer to the parent table is not a required field
+ for input data
+ """
+ data = {
+ 'name1': 'parent name',
+ 'name2': 'child name',
+ }
+ serializer = DerivedModelSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py
new file mode 100644
index 00000000..04b89eb6
--- /dev/null
+++ b/tests/test_negotiation.py
@@ -0,0 +1,45 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.negotiation import DefaultContentNegotiation
+from rest_framework.request import Request
+from rest_framework.renderers import BaseRenderer
+from rest_framework.test import APIRequestFactory
+
+
+factory = APIRequestFactory()
+
+
+class MockJSONRenderer(BaseRenderer):
+ media_type = 'application/json'
+
+
+class MockHTMLRenderer(BaseRenderer):
+ media_type = 'text/html'
+
+
+class NoCharsetSpecifiedRenderer(BaseRenderer):
+ media_type = 'my/media'
+
+
+class TestAcceptedMediaType(TestCase):
+ def setUp(self):
+ self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
+ self.negotiator = DefaultContentNegotiation()
+
+ def select_renderer(self, request):
+ return self.negotiator.select_renderer(request, self.renderers)
+
+ def test_client_without_accept_use_renderer(self):
+ request = Request(factory.get('/'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json')
+
+ def test_client_underspecifies_accept_use_renderer(self):
+ request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json')
+
+ def test_client_overspecifies_accept_use_client(self):
+ request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json; indent=8')
diff --git a/tests/test_nullable_fields.py b/tests/test_nullable_fields.py
new file mode 100644
index 00000000..33a9685f
--- /dev/null
+++ b/tests/test_nullable_fields.py
@@ -0,0 +1,30 @@
+from django.core.urlresolvers import reverse
+
+from rest_framework.compat import patterns, url
+from rest_framework.test import APITestCase
+from tests.models import NullableForeignKeySource
+from tests.serializers import NullableFKSourceSerializer
+from tests.views import NullableFKSourceDetail
+
+
+urlpatterns = patterns(
+ '',
+ url(r'^objects/(?P\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'),
+)
+
+
+class NullableForeignKeyTests(APITestCase):
+ """
+ DRF should be able to handle nullable foreign keys when a test
+ Client POST/PUT request is made with its own serialized object.
+ """
+ urls = 'tests.test_nullable_fields'
+
+ def test_updating_object_with_null_fk(self):
+ obj = NullableForeignKeySource(name='example', target=None)
+ obj.save()
+ serialized_data = NullableFKSourceSerializer(obj).data
+
+ response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data)
+
+ self.assertEqual(response.data, serialized_data)
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
new file mode 100644
index 00000000..65fa9dcd
--- /dev/null
+++ b/tests/test_pagination.py
@@ -0,0 +1,517 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.core.paginator import Paginator
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import generics, status, pagination, filters, serializers
+from rest_framework.compat import django_filters
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel
+
+factory = APIRequestFactory()
+
+
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+class RootView(generics.ListCreateAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+ paginate_by = 10
+
+
+class DefaultPageSizeKwargView(generics.ListAPIView):
+ """
+ View for testing default paginate_by_param usage
+ """
+ model = BasicModel
+
+
+class PaginateByParamView(generics.ListAPIView):
+ """
+ View for testing custom paginate_by_param usage
+ """
+ model = BasicModel
+ paginate_by_param = 'page_size'
+
+
+class MaxPaginateByView(generics.ListAPIView):
+ """
+ View for testing custom max_paginate_by usage
+ """
+ model = BasicModel
+ paginate_by = 3
+ max_paginate_by = 5
+ paginate_by_param = 'page_size'
+
+
+class IntegrationTestPagination(TestCase):
+ """
+ Integration tests for paginated list views.
+ """
+
+ def setUp(self):
+ """
+ Create 26 BasicModel instances.
+ """
+ for char in 'abcdefghijklmnopqrstuvwxyz':
+ BasicModel(text=char * 3).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_paginated_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ request = factory.get('/')
+ # Note: Database queries are a `SELECT COUNT`, and `SELECT `
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[10:20])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 26)
+ self.assertEqual(response.data['results'], self.data[20:])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+
+class IntegrationTestPaginationAndFiltering(TestCase):
+
+ def setUp(self):
+ """
+ Create 50 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(26):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_django_filter_paginated_filtered_root_view(self):
+ """
+ GET requests to paginated filtered ListCreateAPIView should return
+ paginated results. The next and previous links should preserve the
+ filtered parameters.
+ """
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ view = FilterFieldsRootView.as_view()
+
+ EXPECTED_NUM_QUERIES = 2
+
+ request = factory.get('/?decimal=15.20')
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ with self.assertNumQueries(EXPECTED_NUM_QUERIES):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ def test_get_basic_paginated_filtered_root_view(self):
+ """
+ Same as `test_get_django_filter_paginated_filtered_root_view`,
+ except using a custom filter backend instead of the django-filter
+ backend,
+ """
+
+ class DecimalFilterBackend(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
+
+ class BasicFilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_backends = (DecimalFilterBackend,)
+
+ view = BasicFilterFieldsRootView.as_view()
+
+ request = factory.get('/?decimal=15.20')
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[10:15])
+ self.assertEqual(response.data['next'], None)
+ self.assertNotEqual(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ with self.assertNumQueries(2):
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['count'], 15)
+ self.assertEqual(response.data['results'], self.data[:10])
+ self.assertNotEqual(response.data['next'], None)
+ self.assertEqual(response.data['previous'], None)
+
+
+class PassOnContextPaginationSerializer(pagination.PaginationSerializer):
+ class Meta:
+ object_serializer_class = serializers.Serializer
+
+
+class UnitTestPagination(TestCase):
+ """
+ Unit tests for pagination of primitive objects.
+ """
+
+ def setUp(self):
+ self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
+ paginator = Paginator(self.objects, 10)
+ self.first_page = paginator.page(1)
+ self.last_page = paginator.page(3)
+
+ def test_native_pagination(self):
+ serializer = pagination.PaginationSerializer(self.first_page)
+ self.assertEqual(serializer.data['count'], 26)
+ self.assertEqual(serializer.data['next'], '?page=2')
+ self.assertEqual(serializer.data['previous'], None)
+ self.assertEqual(serializer.data['results'], self.objects[:10])
+
+ serializer = pagination.PaginationSerializer(self.last_page)
+ self.assertEqual(serializer.data['count'], 26)
+ self.assertEqual(serializer.data['next'], None)
+ self.assertEqual(serializer.data['previous'], '?page=2')
+ self.assertEqual(serializer.data['results'], self.objects[20:])
+
+ def test_context_available_in_result(self):
+ """
+ Ensure context gets passed through to the object serializer.
+ """
+ serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})
+ serializer.data
+ results = serializer.fields[serializer.results_field]
+ self.assertEqual(serializer.context, results.context)
+
+
+class TestUnpaginated(TestCase):
+ """
+ Tests for list views without pagination.
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = DefaultPageSizeKwargView.as_view()
+
+ def test_unpaginated(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ self.assertEqual(response.data, self.data)
+
+
+class TestCustomPaginateByParam(TestCase):
+ """
+ Tests for list views with default page size kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = PaginateByParamView.as_view()
+
+ def test_default_page_size(self):
+ """
+ Tests the default page size for this view.
+ no page size --> no limit --> no meta data
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEqual(response.data, self.data)
+
+ def test_paginate_by_param(self):
+ """
+ If paginate_by_param is set, the new kwarg should limit per view requests.
+ """
+ request = factory.get('/?page_size=5')
+ response = self.view(request).render()
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+
+class TestMaxPaginateByParam(TestCase):
+ """
+ Tests for list views with max_paginate_by kwarg
+ """
+
+ def setUp(self):
+ """
+ Create 13 BasicModel instances.
+ """
+ for i in range(13):
+ BasicModel(text=i).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = MaxPaginateByView.as_view()
+
+ def test_max_paginate_by(self):
+ """
+ If max_paginate_by is set, it should limit page size for the view.
+ """
+ request = factory.get('/?page_size=10')
+ response = self.view(request).render()
+ self.assertEqual(response.data['count'], 13)
+ self.assertEqual(response.data['results'], self.data[:5])
+
+ def test_max_paginate_by_without_page_size_param(self):
+ """
+ If max_paginate_by is set, but client does not specifiy page_size,
+ standard `paginate_by` behavior should be used.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEqual(response.data['results'], self.data[:3])
+
+
+### Tests for context in pagination serializers
+
+class CustomField(serializers.Field):
+ def to_native(self, value):
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into custom field")
+ return "value"
+
+
+class BasicModelSerializer(serializers.Serializer):
+ text = CustomField()
+
+ def __init__(self, *args, **kwargs):
+ super(BasicModelSerializer, self).__init__(*args, **kwargs)
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into serializer init")
+
+
+class TestContextPassedToCustomField(TestCase):
+ def setUp(self):
+ BasicModel.objects.create(text='ala ma kota')
+
+ def test_with_pagination(self):
+ class ListView(generics.ListCreateAPIView):
+ model = BasicModel
+ serializer_class = BasicModelSerializer
+ paginate_by = 1
+
+ self.view = ListView.as_view()
+ request = factory.get('/')
+ response = self.view(request).render()
+
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+### Tests for custom pagination serializers
+
+class LinksSerializer(serializers.Serializer):
+ next = pagination.NextPageField(source='*')
+ prev = pagination.PreviousPageField(source='*')
+
+
+class CustomPaginationSerializer(pagination.BasePaginationSerializer):
+ links = LinksSerializer(source='*') # Takes the page object as the source
+ total_results = serializers.Field(source='paginator.count')
+
+ results_field = 'objects'
+
+
+class TestCustomPaginationSerializer(TestCase):
+ def setUp(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = Paginator(objects, 2)
+ self.page = paginator.page(1)
+
+ def test_custom_pagination_serializer(self):
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=self.page,
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page=2',
+ 'prev': None
+ },
+ 'total_results': 4,
+ 'objects': ['john', 'paul']
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+class NonIntegerPage(object):
+
+ def __init__(self, paginator, object_list, prev_token, token, next_token):
+ self.paginator = paginator
+ self.object_list = object_list
+ self.prev_token = prev_token
+ self.token = token
+ self.next_token = next_token
+
+ def has_next(self):
+ return not not self.next_token
+
+ def next_page_number(self):
+ return self.next_token
+
+ def has_previous(self):
+ return not not self.prev_token
+
+ def previous_page_number(self):
+ return self.prev_token
+
+
+class NonIntegerPaginator(object):
+
+ def __init__(self, object_list, per_page):
+ self.object_list = object_list
+ self.per_page = per_page
+
+ def count(self):
+ # pretend like we don't know how many pages we have
+ return None
+
+ def page(self, token=None):
+ if token:
+ try:
+ first = self.object_list.index(token)
+ except ValueError:
+ first = 0
+ else:
+ first = 0
+ n = len(self.object_list)
+ last = min(first + self.per_page, n)
+ prev_token = self.object_list[last - (2 * self.per_page)] if first else None
+ next_token = self.object_list[last] if last < n else None
+ return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token)
+
+
+class TestNonIntegerPagination(TestCase):
+
+
+ def test_custom_pagination_serializer(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = NonIntegerPaginator(objects, 2)
+
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=paginator.page(),
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page={0}'.format(objects[2]),
+ 'prev': None
+ },
+ 'total_results': None,
+ 'objects': objects[:2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ request = APIRequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=paginator.page('george'),
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': None,
+ 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]),
+ },
+ 'total_results': None,
+ 'objects': objects[2:]
+ }
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
new file mode 100644
index 00000000..7699e10c
--- /dev/null
+++ b/tests/test_parsers.py
@@ -0,0 +1,115 @@
+from __future__ import unicode_literals
+from rest_framework.compat import StringIO
+from django import forms
+from django.core.files.uploadhandler import MemoryFileUploadHandler
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework.compat import etree
+from rest_framework.parsers import FormParser, FileUploadParser
+from rest_framework.parsers import XMLParser
+import datetime
+
+
+class Form(forms.Form):
+ field1 = forms.CharField(max_length=3)
+ field2 = forms.CharField()
+
+
+class TestFormParser(TestCase):
+ def setUp(self):
+ self.string = "field1=abc&field2=defghijk"
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FormParser()
+
+ stream = StringIO(self.string)
+ data = parser.parse(stream)
+
+ self.assertEqual(Form(data).is_valid(), True)
+
+
+class TestXMLParser(TestCase):
+ def setUp(self):
+ self._input = StringIO(
+ ''
+ ''
+ '121.0'
+ 'dasd'
+ ''
+ '2011-12-25 12:45:00'
+ ''
+ )
+ self._data = {
+ 'field_a': 121,
+ 'field_b': 'dasd',
+ 'field_c': None,
+ 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
+ }
+ self._complex_data_input = StringIO(
+ ''
+ ''
+ '2011-12-25 12:45:00'
+ ''
+ '1first'
+ '2second'
+ ''
+ 'name'
+ ''
+ )
+ self._complex_data = {
+ "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
+ "name": "name",
+ "sub_data_list": [
+ {
+ "sub_id": 1,
+ "sub_name": "first"
+ },
+ {
+ "sub_id": 2,
+ "sub_name": "second"
+ }
+ ]
+ }
+
+ @unittest.skipUnless(etree, 'defusedxml not installed')
+ def test_parse(self):
+ parser = XMLParser()
+ data = parser.parse(self._input)
+ self.assertEqual(data, self._data)
+
+ @unittest.skipUnless(etree, 'defusedxml not installed')
+ def test_complex_data_parse(self):
+ parser = XMLParser()
+ data = parser.parse(self._complex_data_input)
+ self.assertEqual(data, self._complex_data)
+
+
+class TestFileUploadParser(TestCase):
+ def setUp(self):
+ class MockRequest(object):
+ pass
+ from io import BytesIO
+ self.stream = BytesIO(
+ "Test text file".encode('utf-8')
+ )
+ request = MockRequest()
+ request.upload_handlers = (MemoryFileUploadHandler(),)
+ request.META = {
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
+ 'HTTP_CONTENT_LENGTH': 14,
+ }
+ self.parser_context = {'request': request, 'kwargs': {}}
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ data_and_files = parser.parse(self.stream, None, self.parser_context)
+ file_obj = data_and_files.files['file']
+ self.assertEqual(file_obj._size, 14)
+
+ def test_get_filename(self):
+ parser = FileUploadParser()
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'file.txt'.encode('utf-8'))
diff --git a/tests/test_permissions.py b/tests/test_permissions.py
new file mode 100644
index 00000000..a2cb0c36
--- /dev/null
+++ b/tests/test_permissions.py
@@ -0,0 +1,291 @@
+from __future__ import unicode_literals
+from django.contrib.auth.models import User, Permission, Group
+from django.db import models
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
+from rest_framework.compat import guardian, get_model_name
+from rest_framework.filters import DjangoObjectPermissionsFilter
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel
+import base64
+
+factory = APIRequestFactory()
+
+class RootView(generics.ListCreateAPIView):
+ model = BasicModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+root_view = RootView.as_view()
+instance_view = InstanceView.as_view()
+
+
+def basic_auth_header(username, password):
+ credentials = ('%s:%s' % (username, password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ return 'Basic %s' % base64_credentials
+
+
+class ModelPermissionsIntegrationTests(TestCase):
+ def setUp(self):
+ User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
+ user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='add_basicmodel'),
+ Permission.objects.get(codename='change_basicmodel'),
+ Permission.objects.get(codename='delete_basicmodel')
+ ]
+ user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='change_basicmodel'),
+ ]
+
+ self.permitted_credentials = basic_auth_header('permitted', 'password')
+ self.disallowed_credentials = basic_auth_header('disallowed', 'password')
+ self.updateonly_credentials = basic_auth_header('updateonly', 'password')
+
+ BasicModel(text='foo').save()
+
+ def test_has_create_permissions(self):
+ request = factory.post('/', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+ def test_has_put_permissions(self):
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_has_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_does_not_have_create_permissions(self):
+ request = factory.post('/', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_put_permissions(self):
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_has_put_as_create_permissions(self):
+ # User only has update permissions - should be able to update an entity.
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ # But if PUTing to a new entity, permission should be denied.
+ request = factory.put('/2', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='2')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_options_permitted(self):
+ request = factory.options('/',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['POST'])
+
+ request = factory.options('/1',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
+ def test_options_disallowed(self):
+ request = factory.options('/',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options('/1',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ def test_options_updateonly(self):
+ request = factory.options('/',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options('/1',
+ HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
+
+class BasicPermModel(models.Model):
+ text = models.CharField(max_length=100)
+
+ class Meta:
+ app_label = 'tests'
+ permissions = (
+ ('view_basicpermmodel', 'Can view basic perm model'),
+ # add, change, delete built in to django
+ )
+
+# Custom object-level permission, that includes 'view' permissions
+class ViewObjectPermissions(permissions.DjangoObjectPermissions):
+ perms_map = {
+ 'GET': ['%(app_label)s.view_%(model_name)s'],
+ 'OPTIONS': ['%(app_label)s.view_%(model_name)s'],
+ 'HEAD': ['%(app_label)s.view_%(model_name)s'],
+ 'POST': ['%(app_label)s.add_%(model_name)s'],
+ 'PUT': ['%(app_label)s.change_%(model_name)s'],
+ 'PATCH': ['%(app_label)s.change_%(model_name)s'],
+ 'DELETE': ['%(app_label)s.delete_%(model_name)s'],
+ }
+
+
+class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicPermModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [ViewObjectPermissions]
+
+object_permissions_view = ObjectPermissionInstanceView.as_view()
+
+
+class ObjectPermissionListView(generics.ListAPIView):
+ model = BasicPermModel
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [ViewObjectPermissions]
+
+object_permissions_list_view = ObjectPermissionListView.as_view()
+
+
+@unittest.skipUnless(guardian, 'django-guardian not installed')
+class ObjectPermissionsIntegrationTests(TestCase):
+ """
+ Integration tests for the object level permissions API.
+ """
+ def setUp(self):
+ from guardian.shortcuts import assign_perm
+
+ # create users
+ create = User.objects.create_user
+ users = {
+ 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
+ 'readonly': create('readonly', 'readonly@example.com', 'password'),
+ 'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
+ 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
+ }
+
+ # give everyone model level permissions, as we are not testing those
+ everyone = Group.objects.create(name='everyone')
+ model_name = get_model_name(BasicPermModel)
+ app_label = BasicPermModel._meta.app_label
+ f = '{0}_{1}'.format
+ perms = {
+ 'view': f('view', model_name),
+ 'change': f('change', model_name),
+ 'delete': f('delete', model_name)
+ }
+ for perm in perms.values():
+ perm = '{0}.{1}'.format(app_label, perm)
+ assign_perm(perm, everyone)
+ everyone.user_set.add(*users.values())
+
+ # appropriate object level permissions
+ readers = Group.objects.create(name='readers')
+ writers = Group.objects.create(name='writers')
+ deleters = Group.objects.create(name='deleters')
+
+ model = BasicPermModel.objects.create(text='foo')
+
+ assign_perm(perms['view'], readers, model)
+ assign_perm(perms['change'], writers, model)
+ assign_perm(perms['delete'], deleters, model)
+
+ readers.user_set.add(users['fullaccess'], users['readonly'])
+ writers.user_set.add(users['fullaccess'], users['writeonly'])
+ deleters.user_set.add(users['fullaccess'], users['deleteonly'])
+
+ self.credentials = {}
+ for user in users.values():
+ self.credentials[user.username] = basic_auth_header(user.username, 'password')
+
+ # Delete
+ def test_can_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_cannot_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ # Update
+ def test_can_update_permissions(self):
+ request = factory.patch('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data.get('text'), 'foobar')
+
+ def test_cannot_update_permissions(self):
+ request = factory.patch('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['deleteonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ def test_cannot_update_permissions_non_existing(self):
+ request = factory.patch('/999', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['deleteonly'])
+ response = object_permissions_view(request, pk='999')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Read
+ def test_can_read_permissions(self):
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_cannot_read_permissions(self):
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Read list
+ def test_can_read_list_permissions(self):
+ request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ response = object_permissions_list_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data[0].get('id'), 1)
+
+ def test_cannot_read_list_permissions(self):
+ request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ response = object_permissions_list_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertListEqual(response.data, [])
diff --git a/tests/test_relations.py b/tests/test_relations.py
new file mode 100644
index 00000000..bfc8d487
--- /dev/null
+++ b/tests/test_relations.py
@@ -0,0 +1,120 @@
+"""
+General tests for relational fields.
+"""
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import BlogPost
+
+
+class NullModel(models.Model):
+ pass
+
+
+class FieldTests(TestCase):
+ def test_pk_related_field_with_empty_string(self):
+ """
+ Regression test for #446
+
+ https://github.com/tomchristie/django-rest-framework/issues/446
+ """
+ field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_hyperlinked_related_field_with_empty_string(self):
+ field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+ def test_slug_related_field_with_empty_string(self):
+ field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
+ self.assertRaises(serializers.ValidationError, field.from_native, '')
+ self.assertRaises(serializers.ValidationError, field.from_native, [])
+
+
+class TestManyRelatedMixin(TestCase):
+ def test_missing_many_to_many_related_field(self):
+ '''
+ Regression test for #632
+
+ https://github.com/tomchristie/django-rest-framework/pull/632
+ '''
+ field = serializers.RelatedField(many=True, read_only=False)
+
+ into = {}
+ field.field_from_native({}, None, 'field_name', into)
+ self.assertEqual(into['field_name'], [])
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class RelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.RelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['BlogPost object'])
+
+ # Regression for #1129
+ def test_exception_for_incorect_fk(self):
+ """
+ Check that the exception message are correct if the source field
+ doesn't exist.
+ """
+ from tests.models import ManyToManySource
+ class Meta:
+ model = ManyToManySource
+ attrs = {
+ 'name': serializers.SlugRelatedField(
+ slug_field='name', source='banzai'),
+ 'Meta': Meta,
+ }
+
+ TestSerializer = type(str('TestSerializer'),
+ (serializers.ModelSerializer,), attrs)
+ with self.assertRaises(AttributeError):
+ TestSerializer(data={'name': 'foo'})
diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py
new file mode 100644
index 00000000..98f68d29
--- /dev/null
+++ b/tests/test_relations_hyperlink.py
@@ -0,0 +1,524 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.compat import patterns, url
+from rest_framework.test import APIRequestFactory
+from tests.models import (
+ BlogPost,
+ ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+)
+
+factory = APIRequestFactory()
+request = factory.get('/') # Just to ensure we have a request in the serializer context
+
+
+def dummy_view(request, pk):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^dummyurl/(?P[0-9]+)/$', dummy_view, name='dummy-url'),
+ url(r'^manytomanysource/(?P[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
+ url(r'^manytomanytarget/(?P[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
+ url(r'^foreignkeysource/(?P[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
+ url(r'^foreignkeytarget/(?P[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
+ url(r'^nullableforeignkeysource/(?P[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
+ url(r'^onetoonetarget/(?P[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
+ url(r'^nullableonetoonesource/(?P[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
+)
+
+
+# ManyToMany
+class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
+
+
+class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+
+
+# ForeignKey
+class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
+
+
+class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+ fields = ('url', 'name', 'target')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
+
+
+# Nullable OneToOne
+class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class HyperlinkedManyToManyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ serializer = ManyToManySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class HyperlinkedForeignKeyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected url string, received int.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+
+
+class HyperlinkedNullableForeignKeyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEqual(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'sources': [1]},
+ # {'id': 2, 'name': 'target-2', 'sources': []},
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class HyperlinkedNullableOneToOneTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
+ {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+class HyperlinkedRelatedFieldSourceTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_manager',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='get_blogposts_queryset',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.HyperlinkedRelatedField(
+ many=True,
+ source='a.b.c',
+ view_name='dummy-url',
+ )
+ field.context = {'request': request}
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, ['http://testserver/dummyurl/1/'])
diff --git a/tests/test_relations_nested.py b/tests/test_relations_nested.py
new file mode 100644
index 00000000..d393b0c3
--- /dev/null
+++ b/tests/test_relations_nested.py
@@ -0,0 +1,328 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class OneToOneTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class OneToOneSource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, related_name='source',
+ null=True, blank=True)
+
+
+class OneToManyTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class OneToManySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(OneToManyTarget, related_name='sources')
+
+
+class ReverseNestedOneToOneTests(TestCase):
+ def setUp(self):
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name')
+
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ source = OneToOneSourceSerializer()
+
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'source')
+
+ self.Serializer = OneToOneTargetSerializer
+
+ for idx in range(1, 4):
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}},
+ {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]})
+
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ instance = OneToOneTarget.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class ForwardNestedOneToOneTests(TestCase):
+ def setUp(self):
+ class OneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name')
+
+ class OneToOneSourceSerializer(serializers.ModelSerializer):
+ target = OneToOneTargetSerializer()
+
+ class Meta:
+ model = OneToOneSource
+ fields = ('id', 'name', 'target')
+
+ self.Serializer = OneToOneSourceSerializer
+
+ for idx in range(1, 4):
+ target = OneToOneTarget(name='target-%d' % idx)
+ target.save()
+ source = OneToOneSource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_one_to_one_retrieve(self):
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure (target 4, target_source 4, source 4) are added, and
+ # everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}},
+ {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_create_with_invalid_data(self):
+ data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]})
+
+ def test_one_to_one_update(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+
+ # Ensure (target 3, target_source 3, source 3) are updated,
+ # and everything else is as expected.
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_one_update_to_null(self):
+ data = {'id': 3, 'name': 'source-3-updated', 'target': None}
+ instance = OneToOneSource.objects.get(pk=3)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-3-updated')
+ self.assertEqual(obj.target, None)
+
+ queryset = OneToOneSource.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}},
+ {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}},
+ {'id': 3, 'name': 'source-3-updated', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # TODO: Nullable 1-1 tests
+ # def test_one_to_one_delete(self):
+ # data = {'id': 3, 'name': 'target-3', 'target_source': None}
+ # instance = OneToOneTarget.objects.get(pk=3)
+ # serializer = self.Serializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # serializer.save()
+
+ # # Ensure (target_source 3, source 3) are deleted,
+ # # and everything else is as expected.
+ # queryset = OneToOneTarget.objects.all()
+ # serializer = self.Serializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}},
+ # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}},
+ # {'id': 3, 'name': 'target-3', 'source': None}
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class ReverseNestedOneToManyTests(TestCase):
+ def setUp(self):
+ class OneToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToManySource
+ fields = ('id', 'name')
+
+ class OneToManyTargetSerializer(serializers.ModelSerializer):
+ sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
+
+ class Meta:
+ model = OneToManyTarget
+ fields = ('id', 'name', 'sources')
+
+ self.Serializer = OneToManyTargetSerializer
+
+ target = OneToManyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ source = OneToManySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_one_to_many_retrieve(self):
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1')
+
+ # Ensure source 4 is added, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create_with_invalid_data(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4}]}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
+
+ def test_one_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1-updated')
+
+ # Ensure (target 1, source 1) are updated,
+ # and everything else is as expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_delete(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+
+ # Ensure source 2 is deleted, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py
new file mode 100644
index 00000000..ff59b250
--- /dev/null
+++ b/tests/test_relations_pk.py
@@ -0,0 +1,551 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import (
+ BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
+)
+from rest_framework.compat import six
+
+
+# ManyToMany
+class ManyToManyTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ManyToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+
+
+# ForeignKey
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
+
+
+# Nullable OneToOne
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class PKManyToManyTests(TestCase):
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
+ serializer = ManyToManySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ self.assertFalse(serializer.fields['targets'].read_only)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
+ {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertFalse(serializer.fields['sources'].read_only)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]},
+ {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class PKForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 2},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': 2}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1},
+ {'id': 4, 'name': 'source-4', 'target': 2},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+
+ def test_foreign_key_with_empty(self):
+ """
+ Regression test for #1072
+
+ https://github.com/tomchristie/django-rest-framework/issues/1072
+ """
+ serializer = NullableForeignKeySourceSerializer()
+ self.assertEqual(serializer.data['target'], None)
+
+
+class PKNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEqual(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ # expected = [
+ # {'id': 1, 'name': 'target-1', 'sources': [1]},
+ # {'id': 2, 'name': 'target-2', 'sources': []},
+ # ]
+ # self.assertEqual(serializer.data, expected)
+
+
+class PKNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=new_target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'nullable_source': None},
+ {'id': 2, 'name': 'target-2', 'nullable_source': 1},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+# The below models and tests ensure that serializer fields corresponding
+# to a ManyToManyField field with a user-specified ``through`` model are
+# set to read only
+
+
+class ManyToManyThroughTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManyThrough(models.Model):
+ source = models.ForeignKey('ManyToManyThroughSource')
+ target = models.ForeignKey(ManyToManyThroughTarget)
+
+
+class ManyToManyThroughSource(models.Model):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyThroughTarget,
+ related_name='sources',
+ through='ManyToManyThrough')
+
+
+class ManyToManyThroughTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyThroughTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ManyToManyThroughSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyThroughSource
+ fields = ('id', 'name', 'targets')
+
+
+class PKManyToManyThroughTests(TestCase):
+ def setUp(self):
+ self.source = ManyToManyThroughSource.objects.create(
+ name='through-source-1')
+ self.target = ManyToManyThroughTarget.objects.create(
+ name='through-target-1')
+
+ def test_many_to_many_create(self):
+ data = {'id': 2, 'name': 'source-2', 'targets': [self.target.pk]}
+ serializer = ManyToManyThroughSourceSerializer(data=data)
+ self.assertTrue(serializer.fields['targets'].read_only)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(obj.name, 'source-2')
+ self.assertEqual(obj.targets.count(), 0)
+
+ def test_many_to_many_reverse_create(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': [self.source.pk]}
+ serializer = ManyToManyThroughTargetSerializer(data=data)
+ self.assertTrue(serializer.fields['sources'].read_only)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ obj = serializer.save()
+ self.assertEqual(obj.name, 'target-2')
+ self.assertEqual(obj.sources.count(), 0)
+
+
+# Regression tests for #694 (`source` attribute on related fields)
+
+
+class PrimaryKeyRelatedFieldSourceTests(TestCase):
+ def test_related_manager_source(self):
+ """
+ Relational fields should be able to use manager-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_manager')
+
+ class ClassWithManagerMethod(object):
+ def get_blogposts_manager(self):
+ return BlogPost.objects
+
+ obj = ClassWithManagerMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_related_queryset_source(self):
+ """
+ Relational fields should be able to use queryset-returning methods as their source.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='get_blogposts_queryset')
+
+ class ClassWithQuerysetMethod(object):
+ def get_blogposts_queryset(self):
+ return BlogPost.objects.all()
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
+
+ def test_dotted_source(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ BlogPost.objects.create(title='blah')
+ field = serializers.PrimaryKeyRelatedField(many=True, source='a.b.c')
+
+ class ClassWithQuerysetMethod(object):
+ a = {
+ 'b': {
+ 'c': BlogPost.objects.all()
+ }
+ }
+
+ obj = ClassWithQuerysetMethod()
+ value = field.field_to_native(obj, 'field_name')
+ self.assertEqual(value, [1])
diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py
new file mode 100644
index 00000000..97ebf23a
--- /dev/null
+++ b/tests/test_relations_slug.py
@@ -0,0 +1,257 @@
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.SlugRelatedField(many=True, slug_field='name')
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name')
+
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(slug_field='name', required=False)
+
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# TODO: M2M Tests, FKTests (Non-nullable), One2One
+class SlugForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-2'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 123}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ serializer.is_valid()
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'},
+ {'id': 4, 'name': 'source-4', 'target': 'target-2'},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field is required.']})
+
+
+class SlugNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, expected_data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
new file mode 100644
index 00000000..b41cff39
--- /dev/null
+++ b/tests/test_renderers.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+from decimal import Decimal
+from django.core.cache import cache
+from django.db import models
+from django.test import TestCase
+from django.utils import unittest
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import status, permissions
+from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
+ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
+from rest_framework.parsers import YAMLParser, XMLParser
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from collections import MutableMapping
+import datetime
+import json
+import pickle
+import re
+
+
+DUMMYSTATUS = status.HTTP_200_OK
+DUMMYCONTENT = 'dummycontent'
+
+RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
+RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
+
+
+expected_results = [
+ ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1, 2, 3]') # Generator
+]
+
+
+class DummyTestModel(models.Model):
+ name = models.CharField(max_length=42, default='')
+
+
+class BasicRendererTests(TestCase):
+ def test_expected_results(self):
+ for value, renderer_cls, expected in expected_results:
+ output = renderer_cls().render(value)
+ self.assertEqual(output, expected)
+
+
+class RendererA(BaseRenderer):
+ media_type = 'mock/renderera'
+ format = "formata"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_A_SERIALIZER(data)
+
+
+class RendererB(BaseRenderer):
+ media_type = 'mock/rendererb'
+ format = "formatb"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_B_SERIALIZER(data)
+
+
+class MockView(APIView):
+ renderer_classes = (RendererA, RendererB)
+
+ def get(self, request, **kwargs):
+ response = Response(DUMMYCONTENT, status=DUMMYSTATUS)
+ return response
+
+
+class MockGETView(APIView):
+ def get(self, request, **kwargs):
+ return Response({'foo': ['bar', 'baz']})
+
+
+
+class MockPOSTView(APIView):
+ def post(self, request, **kwargs):
+ return Response({'foo': request.DATA})
+
+
+class EmptyGETView(APIView):
+ renderer_classes = (JSONRenderer,)
+
+ def get(self, request, **kwargs):
+ return Response(status=status.HTTP_204_NO_CONTENT)
+
+
+class HTMLView(APIView):
+ renderer_classes = (BrowsableAPIRenderer, )
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLView1(APIView):
+ renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+urlpatterns = patterns('',
+ url(r'^.*\.(?P.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^cache$', MockGETView.as_view()),
+ url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
+ url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
+ url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
+ url(r'^html$', HTMLView.as_view()),
+ url(r'^html1$', HTMLView1.as_view()),
+ url(r'^empty$', EmptyGETView.as_view()),
+ url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+class POSTDeniedPermission(permissions.BasePermission):
+ def has_permission(self, request, view):
+ return request.method != 'POST'
+
+
+class POSTDeniedView(APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = (POSTDeniedPermission,)
+
+ def get(self, request):
+ return Response()
+
+ def post(self, request):
+ return Response()
+
+ def put(self, request):
+ return Response()
+
+ def patch(self, request):
+ return Response()
+
+
+class DocumentingRendererTests(TestCase):
+ def test_only_permitted_forms_are_displayed(self):
+ view = POSTDeniedView.as_view()
+ request = APIRequestFactory().get('/')
+ response = view(request).render()
+ self.assertNotContains(response, '>POST<')
+ self.assertContains(response, '>PUT<')
+ self.assertContains(response, '>PATCH<')
+
+
+class RendererEndToEndTests(TestCase):
+ """
+ End-to-end testing of renderers using an RendererMixin on a generic view.
+ """
+
+ urls = 'tests.test_renderers'
+
+ def test_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, six.b(''))
+
+ def test_default_renderer_serializes_content_on_accept_any(self):
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for the default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_non_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for a non-default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
+ """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
+ resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
+ self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_parse_error_renderers_browsable_api(self):
+ """Invalid data should still render the browsable API correctly."""
+ resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
+
+ def test_204_no_content_responses_have_no_content_type_set(self):
+ """
+ Regression test for #1196
+
+ https://github.com/tomchristie/django-rest-framework/issues/1196
+ """
+ resp = self.client.get('/empty')
+ self.assertEqual(resp.get('Content-Type', None), None)
+ self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_contains_headers_of_api_response(self):
+ """
+ Issue #1437
+
+ Test we display the headers of the API response and not those from the
+ HTML response
+ """
+ resp = self.client.get('/html1')
+ self.assertContains(resp, '>GET, HEAD, OPTIONS<')
+ self.assertContains(resp, '>application/json<')
+ self.assertNotContains(resp, '>text/html; charset=utf-8<')
+
+
+_flat_repr = '{"foo": ["bar", "baz"]}'
+_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
+
+
+def strip_trailing_whitespace(content):
+ """
+ Seems to be some inconsistencies re. trailing whitespace with
+ different versions of the json lib.
+ """
+ return re.sub(' +\n', '\n', content)
+
+
+class JSONRendererTests(TestCase):
+ """
+ Tests specific to the JSON Renderer
+ """
+
+ def test_render_lazy_strings(self):
+ """
+ JSONRenderer should deal with lazy translated strings.
+ """
+ ret = JSONRenderer().render(_('test'))
+ self.assertEqual(ret, b'"test"')
+
+ def test_render_queryset_values(self):
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, [{'id': o.id, 'name': o.name}])
+
+ def test_render_queryset_values_list(self):
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values_list('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, [[o.id, o.name]])
+
+ def test_render_dict_abc_obj(self):
+ class Dict(MutableMapping):
+ def __init__(self):
+ self._dict = dict()
+ def __getitem__(self, key):
+ return self._dict.__getitem__(key)
+ def __setitem__(self, key, value):
+ return self._dict.__setitem__(key, value)
+ def __delitem__(self, key):
+ return self._dict.__delitem__(key)
+ def __iter__(self):
+ return self._dict.__iter__()
+ def __len__(self):
+ return self._dict.__len__()
+ def keys(self):
+ return self._dict.keys()
+
+ x = Dict()
+ x['key'] = 'string value'
+ x[2] = 3
+ ret = JSONRenderer().render(x)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, {'key': 'string value', '2': 3})
+
+ def test_render_obj_with_getitem(self):
+ class DictLike(object):
+ def __init__(self):
+ self._dict = {}
+ def set(self, value):
+ self._dict = dict(value)
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ x = DictLike()
+ x.set({'a': 1, 'b': 'string'})
+ with self.assertRaises(TypeError):
+ JSONRenderer().render(x)
+
+ def test_without_content_type_args(self):
+ """
+ Test basic JSON rendering.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ # Fix failing test case which depends on version of JSON library.
+ self.assertEqual(content.decode('utf-8'), _flat_repr)
+
+ def test_with_content_type_args(self):
+ """
+ Test JSON rendering with additional content type arguments supplied.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json; indent=2')
+ self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
+
+ def test_check_ascii(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries": ["United Kingdom", "France", "Espa\\u00f1a"]}'.encode('utf-8'))
+
+
+class UnicodeJSONRendererTests(TestCase):
+ """
+ Tests specific for the Unicode JSON Renderer
+ """
+ def test_proper_encoding(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = UnicodeJSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries": ["United Kingdom", "France", "España"]}'.encode('utf-8'))
+
+
+class JSONPRendererTests(TestCase):
+ """
+ Tests specific to the JSONP Renderer
+ """
+
+ urls = 'tests.test_renderers'
+
+ def test_without_callback_with_json_renderer(self):
+ """
+ Test JSONP rendering with View JSON Renderer.
+ """
+ resp = self.client.get('/jsonp/jsonrenderer',
+ HTTP_ACCEPT='application/javascript')
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
+ self.assertEqual(resp.content,
+ ('callback(%s);' % _flat_repr).encode('ascii'))
+
+ def test_without_callback_without_json_renderer(self):
+ """
+ Test JSONP rendering without View JSON Renderer.
+ """
+ resp = self.client.get('/jsonp/nojsonrenderer',
+ HTTP_ACCEPT='application/javascript')
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
+ self.assertEqual(resp.content,
+ ('callback(%s);' % _flat_repr).encode('ascii'))
+
+ def test_with_callback(self):
+ """
+ Test JSONP rendering with callback function name.
+ """
+ callback_func = 'myjsonpcallback'
+ resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
+ HTTP_ACCEPT='application/javascript')
+ self.assertEqual(resp.status_code, status.HTTP_200_OK)
+ self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
+ self.assertEqual(resp.content,
+ ('%s(%s);' % (callback_func, _flat_repr)).encode('ascii'))
+
+
+if yaml:
+ _yaml_repr = 'foo: [bar, baz]\n'
+
+ class YAMLRendererTests(TestCase):
+ """
+ Tests specific to the YAML Renderer
+ """
+
+ def test_render(self):
+ """
+ Test basic YAML rendering.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = YAMLRenderer()
+ content = renderer.render(obj, 'application/yaml')
+ self.assertEqual(content, _yaml_repr)
+
+ def test_render_and_parse(self):
+ """
+ Test rendering and then parsing returns the original object.
+ IE obj -> render -> parse -> obj.
+ """
+ obj = {'foo': ['bar', 'baz']}
+
+ renderer = YAMLRenderer()
+ parser = YAMLParser()
+
+ content = renderer.render(obj, 'application/yaml')
+ data = parser.parse(StringIO(content))
+ self.assertEqual(obj, data)
+
+ def test_render_decimal(self):
+ """
+ Test YAML decimal rendering.
+ """
+ renderer = YAMLRenderer()
+ content = renderer.render({'field': Decimal('111.2')}, 'application/yaml')
+ self.assertYAMLContains(content, "field: '111.2'")
+
+ def assertYAMLContains(self, content, string):
+ self.assertTrue(string in content, '%r not in %r' % (string, content))
+
+
+class XMLRendererTestCase(TestCase):
+ """
+ Tests specific to the XML Renderer
+ """
+
+ _complex_data = {
+ "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
+ "name": "name",
+ "sub_data_list": [
+ {
+ "sub_id": 1,
+ "sub_name": "first"
+ },
+ {
+ "sub_id": 2,
+ "sub_name": "second"
+ }
+ ]
+ }
+
+ def test_render_string(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 'astring'}, 'application/xml')
+ self.assertXMLContains(content, 'astring')
+
+ def test_render_integer(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 111}, 'application/xml')
+ self.assertXMLContains(content, '111')
+
+ def test_render_datetime(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({
+ 'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
+ }, 'application/xml')
+ self.assertXMLContains(content, '2011-12-25 12:45:00')
+
+ def test_render_float(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 123.4}, 'application/xml')
+ self.assertXMLContains(content, '123.4')
+
+ def test_render_decimal(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
+ self.assertXMLContains(content, '111.2')
+
+ def test_render_none(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': None}, 'application/xml')
+ self.assertXMLContains(content, '')
+
+ def test_render_complex_data(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render(self._complex_data, 'application/xml')
+ self.assertXMLContains(content, 'first')
+ self.assertXMLContains(content, 'second')
+
+ @unittest.skipUnless(etree, 'defusedxml not installed')
+ def test_render_and_parse_complex_data(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = StringIO(renderer.render(self._complex_data, 'application/xml'))
+
+ parser = XMLParser()
+ complex_data_out = parser.parse(content)
+ error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
+ self.assertEqual(self._complex_data, complex_data_out, error_msg)
+
+ def assertXMLContains(self, xml, string):
+ self.assertTrue(xml.startswith('\n'))
+ self.assertTrue(xml.endswith(''))
+ self.assertTrue(string in xml, '%r not in %r' % (string, xml))
+
+
+# Tests for caching issue, #346
+class CacheRenderTest(TestCase):
+ """
+ Tests specific to caching responses
+ """
+
+ urls = 'tests.test_renderers'
+
+ cache_key = 'just_a_cache_key'
+
+ @classmethod
+ def _get_pickling_errors(cls, obj, seen=None):
+ """ Return any errors that would be raised if `obj' is pickled
+ Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
+ """
+ if seen == None:
+ seen = []
+ try:
+ state = obj.__getstate__()
+ except AttributeError:
+ return
+ if state == None:
+ return
+ if isinstance(state, tuple):
+ if not isinstance(state[0], dict):
+ state = state[1]
+ else:
+ state = state[0].update(state[1])
+ result = {}
+ for i in state:
+ try:
+ pickle.dumps(state[i], protocol=2)
+ except pickle.PicklingError:
+ if not state[i] in seen:
+ seen.append(state[i])
+ result[i] = cls._get_pickling_errors(state[i], seen)
+ return result
+
+ def http_resp(self, http_method, url):
+ """
+ Simple wrapper for Client http requests
+ Removes the `client' and `request' attributes from as they are
+ added by django.test.client.Client and not part of caching
+ responses outside of tests.
+ """
+ method = getattr(self.client, http_method)
+ resp = method(url)
+ del resp.client, resp.request
+ return resp
+
+ def test_obj_pickling(self):
+ """
+ Test that responses are properly pickled
+ """
+ resp = self.http_resp('get', '/cache')
+
+ # Make sure that no pickling errors occurred
+ self.assertEqual(self._get_pickling_errors(resp), {})
+
+ # Unfortunately LocMem backend doesn't raise PickleErrors but returns
+ # None instead.
+ cache.set(self.cache_key, resp)
+ self.assertTrue(cache.get(self.cache_key) is not None)
+
+ def test_head_caching(self):
+ """
+ Test caching of HEAD requests
+ """
+ resp = self.http_resp('head', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+
+ def test_get_caching(self):
+ """
+ Test caching of GET requests
+ """
+ resp = self.http_resp('get', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+ self.assertEqual(cached_resp.content, resp.content)
diff --git a/tests/test_request.py b/tests/test_request.py
new file mode 100644
index 00000000..0a9355f0
--- /dev/null
+++ b/tests/test_request.py
@@ -0,0 +1,347 @@
+"""
+Tests for content parsing, and form-overloaded content parsing.
+"""
+from __future__ import unicode_literals
+from django.contrib.auth.models import User
+from django.contrib.auth import authenticate, login, logout
+from django.contrib.sessions.middleware import SessionMiddleware
+from django.core.handlers.wsgi import WSGIRequest
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.authentication import SessionAuthentication
+from rest_framework.compat import patterns
+from rest_framework.parsers import (
+ BaseParser,
+ FormParser,
+ MultiPartParser,
+ JSONParser
+)
+from rest_framework.request import Request, Empty
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory, APIClient
+from rest_framework.views import APIView
+from rest_framework.compat import six
+from io import BytesIO
+import json
+
+
+factory = APIRequestFactory()
+
+
+class PlainTextParser(BaseParser):
+ media_type = 'text/plain'
+
+ def parse(self, stream, media_type=None, parser_context=None):
+ """
+ Returns a 2-tuple of `(data, files)`.
+
+ `data` will simply be a string representing the body of the request.
+ `files` will always be `None`.
+ """
+ return stream.read()
+
+
+class TestMethodOverloading(TestCase):
+ def test_method(self):
+ """
+ Request methods should be same as underlying request.
+ """
+ request = Request(factory.get('/'))
+ self.assertEqual(request.method, 'GET')
+ request = Request(factory.post('/'))
+ self.assertEqual(request.method, 'POST')
+
+ def test_overloaded_method(self):
+ """
+ POST requests can be overloaded to another method by setting a
+ reserved form field
+ """
+ request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
+ self.assertEqual(request.method, 'DELETE')
+
+ def test_x_http_method_override_header(self):
+ """
+ POST requests can also be overloaded to another method by setting
+ the X-HTTP-Method-Override header.
+ """
+ request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
+ request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
+
+class TestContentParsing(TestCase):
+ def test_standard_behaviour_determines_no_content_GET(self):
+ """
+ Ensure request.DATA returns empty QueryDict for GET request.
+ """
+ request = Request(factory.get('/'))
+ self.assertEqual(request.DATA, {})
+
+ def test_standard_behaviour_determines_no_content_HEAD(self):
+ """
+ Ensure request.DATA returns empty QueryDict for HEAD request.
+ """
+ request = Request(factory.head('/'))
+ self.assertEqual(request.DATA, {})
+
+ def test_request_DATA_with_form_content(self):
+ """
+ Ensure request.DATA returns content for POST request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
+
+ def test_request_DATA_with_text_content(self):
+ """
+ Ensure request.DATA returns content for POST request with
+ non-form content.
+ """
+ content = six.b('qwerty')
+ content_type = 'text/plain'
+ request = Request(factory.post('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(),)
+ self.assertEqual(request.DATA, content)
+
+ def test_request_POST_with_form_content(self):
+ """
+ Ensure request.POST returns content for POST request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.POST.items()), list(data.items()))
+
+ def test_standard_behaviour_determines_form_content_PUT(self):
+ """
+ Ensure request.DATA returns content for PUT request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.put('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
+
+ def test_standard_behaviour_determines_non_form_content_PUT(self):
+ """
+ Ensure request.DATA returns content for PUT request with
+ non-form content.
+ """
+ content = six.b('qwerty')
+ content_type = 'text/plain'
+ request = Request(factory.put('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(), )
+ self.assertEqual(request.DATA, content)
+
+ def test_overloaded_behaviour_allows_content_tunnelling(self):
+ """
+ Ensure request.DATA returns content for overloaded POST request.
+ """
+ json_data = {'foobar': 'qwerty'}
+ content = json.dumps(json_data)
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = Request(factory.post('/', form_data))
+ request.parsers = (JSONParser(), )
+ self.assertEqual(request.DATA, json_data)
+
+ def test_form_POST_unicode(self):
+ """
+ JSON POST via default web interface with unicode data
+ """
+ # Note: environ and other variables here have simplified content compared to real Request
+ CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D'
+ environ = {
+ 'REQUEST_METHOD': 'POST',
+ 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
+ 'CONTENT_LENGTH': len(CONTENT),
+ 'wsgi.input': BytesIO(CONTENT),
+ }
+ wsgi_request = WSGIRequest(environ=environ)
+ wsgi_request._load_post_and_files()
+ parsers = (JSONParser(), FormParser(), MultiPartParser())
+ parser_context = {
+ 'encoding': 'utf-8',
+ 'kwargs': {},
+ 'args': (),
+ }
+ request = Request(wsgi_request, parsers=parsers, parser_context=parser_context)
+ method = request.method
+ self.assertEqual(method, 'POST')
+ self.assertEqual(request._content_type, 'application/json')
+ self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}')
+ self.assertEqual(request._data, Empty)
+ self.assertEqual(request._files, Empty)
+
+ # def test_accessing_post_after_data_form(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in
+ # form request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # request = factory.post('/', data=data)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), data.items())
+
+ # def test_accessing_post_after_data_for_json(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+
+ # request = factory.post('/', content, content_type=content_type,
+ # parsers=parsers)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), [])
+
+ # def test_accessing_post_after_data_for_overloaded_json(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in overloaded
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # form_data = {Request._CONTENT_PARAM: content,
+ # Request._CONTENTTYPE_PARAM: content_type}
+
+ # request = factory.post('/', form_data, parsers=parsers)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), form_data.items())
+
+ # def test_accessing_data_after_post_form(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in
+ # form request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # parsers = (FormParser, MultiPartParser)
+ # request = factory.post('/', data, parsers=parsers)
+
+ # self.assertEqual(request.POST.items(), data.items())
+ # self.assertEqual(request.DATA.items(), data.items())
+
+ # def test_accessing_data_after_post_for_json(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # request = factory.post('/', content, content_type=content_type,
+ # parsers=parsers)
+ # self.assertEqual(request.POST.items(), [])
+ # self.assertEqual(request.DATA.items(), data.items())
+
+ # def test_accessing_data_after_post_for_overloaded_json(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in overloaded
+ # json request
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # form_data = {Request._CONTENT_PARAM: content,
+ # Request._CONTENTTYPE_PARAM: content_type}
+
+ # request = factory.post('/', form_data, parsers=parsers)
+ # self.assertEqual(request.POST.items(), form_data.items())
+ # self.assertEqual(request.DATA.items(), data.items())
+
+
+class MockView(APIView):
+ authentication_classes = (SessionAuthentication,)
+
+ def post(self, request):
+ if request.POST.get('example') is not None:
+ return Response(status=status.HTTP_200_OK)
+
+ return Response(status=status.INTERNAL_SERVER_ERROR)
+
+urlpatterns = patterns('',
+ (r'^$', MockView.as_view()),
+)
+
+
+class TestContentParsingWithAuthentication(TestCase):
+ urls = 'tests.test_request'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
+ """
+ Ensures request.POST exists after SessionAuthentication when user
+ doesn't log in.
+ """
+ content = {'example': 'example'}
+
+ response = self.client.post('/', content)
+ self.assertEqual(status.HTTP_200_OK, response.status_code)
+
+ response = self.csrf_client.post('/', content)
+ self.assertEqual(status.HTTP_200_OK, response.status_code)
+
+ # def test_user_logged_in_authentication_has_post_when_logged_in(self):
+ # """Ensures request.POST exists after UserLoggedInAuthentication when user does log in"""
+ # self.client.login(username='john', password='password')
+ # self.csrf_client.login(username='john', password='password')
+ # content = {'example': 'example'}
+
+ # response = self.client.post('/', content)
+ # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
+
+ # response = self.csrf_client.post('/', content)
+ # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
+
+
+class TestUserSetter(TestCase):
+
+ def setUp(self):
+ # Pass request object through session middleware so session is
+ # available to login and logout functions
+ self.request = Request(factory.get('/'))
+ SessionMiddleware().process_request(self.request)
+
+ User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
+ self.user = authenticate(username='ringo', password='yellow')
+
+ def test_user_can_be_set(self):
+ self.request.user = self.user
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_login(self):
+ login(self.request, self.user)
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_logout(self):
+ self.request.user = self.user
+ self.assertFalse(self.request.user.is_anonymous())
+ logout(self.request)
+ self.assertTrue(self.request.user.is_anonymous())
+
+
+class TestAuthSetter(TestCase):
+
+ def test_auth_can_be_set(self):
+ request = Request(factory.get('/'))
+ request.auth = 'DUMMY'
+ self.assertEqual(request.auth, 'DUMMY')
diff --git a/tests/test_response.py b/tests/test_response.py
new file mode 100644
index 00000000..41c0f49d
--- /dev/null
+++ b/tests/test_response.py
@@ -0,0 +1,278 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from tests.models import BasicModel, BasicModelSerializer
+from rest_framework.compat import patterns, url, include
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework import generics
+from rest_framework import routers
+from rest_framework import status
+from rest_framework.renderers import (
+ BaseRenderer,
+ JSONRenderer,
+ BrowsableAPIRenderer
+)
+from rest_framework import viewsets
+from rest_framework.settings import api_settings
+from rest_framework.compat import six
+
+
+class MockPickleRenderer(BaseRenderer):
+ media_type = 'application/pickle'
+
+
+class MockJsonRenderer(BaseRenderer):
+ media_type = 'application/json'
+
+
+class MockTextMediaRenderer(BaseRenderer):
+ media_type = 'text/html'
+
+DUMMYSTATUS = status.HTTP_200_OK
+DUMMYCONTENT = 'dummycontent'
+
+RENDERER_A_SERIALIZER = lambda x: ('Renderer A: %s' % x).encode('ascii')
+RENDERER_B_SERIALIZER = lambda x: ('Renderer B: %s' % x).encode('ascii')
+
+
+class RendererA(BaseRenderer):
+ media_type = 'mock/renderera'
+ format = "formata"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_A_SERIALIZER(data)
+
+
+class RendererB(BaseRenderer):
+ media_type = 'mock/rendererb'
+ format = "formatb"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_B_SERIALIZER(data)
+
+
+class RendererC(RendererB):
+ media_type = 'mock/rendererc'
+ format = 'formatc'
+ charset = "rendererc"
+
+
+class MockView(APIView):
+ renderer_classes = (RendererA, RendererB, RendererC)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS)
+
+
+class MockViewSettingContentType(APIView):
+ renderer_classes = (RendererA, RendererB, RendererC)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
+
+
+class HTMLView(APIView):
+ renderer_classes = (BrowsableAPIRenderer, )
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLView1(APIView):
+ renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLNewModelViewSet(viewsets.ModelViewSet):
+ model = BasicModel
+
+
+class HTMLNewModelView(generics.ListCreateAPIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = []
+ serializer_class = BasicModelSerializer
+ model = BasicModel
+
+
+new_model_viewset_router = routers.DefaultRouter()
+new_model_viewset_router.register(r'', HTMLNewModelViewSet)
+
+
+urlpatterns = patterns('',
+ url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^.*\.(?P.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^html$', HTMLView.as_view()),
+ url(r'^html1$', HTMLView1.as_view()),
+ url(r'^html_new_model$', HTMLNewModelView.as_view()),
+ url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
+ url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
+class RendererIntegrationTests(TestCase):
+ """
+ End-to-end testing of renderers using an ResponseMixin on a generic view.
+ """
+
+ urls = 'tests.test_response'
+
+ def test_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, six.b(''))
+
+ def test_default_renderer_serializes_content_on_accept_any(self):
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for the default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_non_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for a non-default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+
+class Issue122Tests(TestCase):
+ """
+ Tests that covers #122.
+ """
+ urls = 'tests.test_response'
+
+ def test_only_html_renderer(self):
+ """
+ Test if no infinite recursion occurs.
+ """
+ self.client.get('/html')
+
+ def test_html_renderer_is_first(self):
+ """
+ Test if no infinite recursion occurs.
+ """
+ self.client.get('/html1')
+
+
+class Issue467Tests(TestCase):
+ """
+ Tests for #467
+ """
+
+ urls = 'tests.test_response'
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
+
+
+class Issue807Tests(TestCase):
+ """
+ Covers #807
+ """
+
+ urls = 'tests.test_response'
+
+ def test_does_not_append_charset_by_default(self):
+ """
+ Renderers don't include a charset unless set explicitly.
+ """
+ headers = {"HTTP_ACCEPT": RendererA.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8')
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
+ """
+ If renderer class has charset attribute declared, it gets appended
+ to Response's Content-Type
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_content_type_set_explictly_on_response(self):
+ """
+ The content type may be set explictly on the response.
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/setbyview', **headers)
+ self.assertEqual('setbyview', resp['Content-Type'])
+
+ def test_viewset_label_help_text(self):
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ 'text/html'
+ )
+ resp = self.client.get('/html_new_model_viewset/' + param)
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertContains(resp, 'Text comes here')
+ self.assertContains(resp, 'Text description.')
diff --git a/tests/test_reverse.py b/tests/test_reverse.py
new file mode 100644
index 00000000..3d14a28f
--- /dev/null
+++ b/tests/test_reverse.py
@@ -0,0 +1,27 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.compat import patterns, url
+from rest_framework.reverse import reverse
+from rest_framework.test import APIRequestFactory
+
+factory = APIRequestFactory()
+
+
+def null_view(request):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^view$', null_view, name='view'),
+)
+
+
+class ReverseTests(TestCase):
+ """
+ Tests for fully qualified URLs when using `reverse`.
+ """
+ urls = 'tests.test_reverse'
+
+ def test_reversed_urls_are_fully_qualified(self):
+ request = factory.get('/view')
+ url = reverse('view', request=request)
+ self.assertEqual(url, 'http://testserver/view')
diff --git a/tests/test_routers.py b/tests/test_routers.py
new file mode 100644
index 00000000..084c0e27
--- /dev/null
+++ b/tests/test_routers.py
@@ -0,0 +1,216 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from django.core.exceptions import ImproperlyConfigured
+from rest_framework import serializers, viewsets, permissions
+from rest_framework.compat import include, patterns, url
+from rest_framework.decorators import link, action
+from rest_framework.response import Response
+from rest_framework.routers import SimpleRouter, DefaultRouter
+from rest_framework.test import APIRequestFactory
+
+factory = APIRequestFactory()
+
+urlpatterns = patterns('',)
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @action()
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @action()
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @action(methods=['post', 'delete'])
+ def action3(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @link()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @link()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
+ # check method to function mapping
+ if endpoint == 'action3':
+ methods_map = ['post', 'delete']
+ elif endpoint.startswith('action'):
+ methods_map = ['post']
+ else:
+ methods_map = ['get']
+ for method in methods_map:
+ self.assertEqual(route.mapping[method], endpoint)
+
+
+class RouterTestModel(models.Model):
+ uuid = models.CharField(max_length=20)
+ text = models.CharField(max_length=200)
+
+
+class TestCustomLookupFields(TestCase):
+ """
+ Ensure that custom lookup fields are correctly routed.
+ """
+ urls = 'tests.test_routers'
+
+ def setUp(self):
+ class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RouterTestModel
+ lookup_field = 'uuid'
+ fields = ('url', 'uuid', 'text')
+
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ serializer_class = NoteSerializer
+ lookup_field = 'uuid'
+
+ RouterTestModel.objects.create(uuid='123', text='foo bar')
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+
+ from tests import test_routers
+ urls = getattr(test_routers, 'urlpatterns')
+ urls += patterns('',
+ url(r'^', include(self.router.urls)),
+ )
+
+ def test_custom_lookup_field_route(self):
+ detail_route = self.router.urls[-1]
+ detail_url_pattern = detail_route.regex.pattern
+ self.assertIn('', detail_url_pattern)
+
+ def test_retrieve_lookup_field_list_view(self):
+ response = self.client.get('/notes/')
+ self.assertEqual(response.data,
+ [{
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }]
+ )
+
+ def test_retrieve_lookup_field_detail_view(self):
+ response = self.client.get('/notes/123/')
+ self.assertEqual(response.data,
+ {
+ "url": "http://testserver/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }
+ )
+
+
+class TestTrailingSlashIncluded(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_have_trailing_slash_by_default(self):
+ expected = ['^notes/$', '^notes/(?P[^/]+)/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestTrailingSlashRemoved(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+
+ self.router = SimpleRouter(trailing_slash=False)
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_can_have_trailing_slash_removed(self):
+ expected = ['^notes$', '^notes/(?P[^/.]+)$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestNameableRoot(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ model = RouterTestModel
+ self.router = DefaultRouter()
+ self.router.root_view_name = 'nameable-root'
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_router_has_custom_name(self):
+ expected = 'nameable-root'
+ self.assertEqual(expected, self.urls[0].name)
+
+
+class TestActionKeywordArgs(TestCase):
+ """
+ Ensure keyword arguments passed in the `@action` decorator
+ are properly handled. Refs #940.
+ """
+
+ def setUp(self):
+ class TestViewSet(viewsets.ModelViewSet):
+ permission_classes = []
+
+ @action(permission_classes=[permissions.AllowAny])
+ def custom(self, request, *args, **kwargs):
+ return Response({
+ 'permission_classes': self.permission_classes
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+ self.view = self.router.urls[-1].callback
+
+ def test_action_kwargs(self):
+ request = factory.post('/test/0/custom/')
+ response = self.view(request)
+ self.assertEqual(
+ response.data,
+ {'permission_classes': [permissions.AllowAny]}
+ )
+
+
+class TestActionAppliedToExistingRoute(TestCase):
+ """
+ Ensure `@action` decorator raises an except when applied
+ to an existing route
+ """
+
+ def test_exception_raised_when_action_applied_to_existing_route(self):
+ class TestViewSet(viewsets.ModelViewSet):
+
+ @action()
+ def retrieve(self, request, *args, **kwargs):
+ return Response({
+ 'hello': 'world'
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+
+ with self.assertRaises(ImproperlyConfigured):
+ self.router.urls
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
new file mode 100644
index 00000000..18484afe
--- /dev/null
+++ b/tests/test_serializer.py
@@ -0,0 +1,1857 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+from django.db import models
+from django.db.models.fields import BLANK_CHOICE_DASH
+from django.test import TestCase
+from django.utils.datastructures import MultiValueDict
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import serializers, fields, relations
+from tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
+ BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel,
+ ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel)
+from tests.models import BasicModelSerializer
+import datetime
+import pickle
+
+
+class SubComment(object):
+ def __init__(self, sub_comment):
+ self.sub_comment = sub_comment
+
+
+class Comment(object):
+ def __init__(self, email, content, created):
+ self.email = email
+ self.content = content
+ self.created = created or datetime.datetime.now()
+
+ def __eq__(self, other):
+ return all([getattr(self, attr) == getattr(other, attr)
+ for attr in ('email', 'content', 'created')])
+
+ def get_sub_comment(self):
+ sub_comment = SubComment('And Merry Christmas!')
+ return sub_comment
+
+
+class CommentSerializer(serializers.Serializer):
+ email = serializers.EmailField()
+ content = serializers.CharField(max_length=1000)
+ created = serializers.DateTimeField()
+ sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
+
+ def restore_object(self, data, instance=None):
+ if instance is None:
+ return Comment(**data)
+ for key, val in data.items():
+ setattr(instance, key, val)
+ return instance
+
+
+class NamesSerializer(serializers.Serializer):
+ first = serializers.CharField()
+ last = serializers.CharField(required=False, default='')
+ initials = serializers.CharField(required=False, default='')
+
+
+class PersonIdentifierSerializer(serializers.Serializer):
+ ssn = serializers.CharField()
+ names = NamesSerializer(source='names', required=False)
+
+
+class BookSerializer(serializers.ModelSerializer):
+ isbn = serializers.RegexField(regex=r'^[0-9]{13}$', error_messages={'invalid': 'isbn has to be exact 13 numbers'})
+
+ class Meta:
+ model = Book
+
+
+class ActionItemSerializer(serializers.ModelSerializer):
+
+ class Meta:
+ model = ActionItem
+
+class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
+ """
+ Intended to test that fields with `required=False` are excluded from validation.
+ """
+ title = serializers.CharField(required=False)
+
+ class Meta:
+ model = ActionItem
+ fields = ('title',)
+
+class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
+
+ class Meta:
+ model = ActionItem
+
+ def restore_object(self, data, instance=None):
+ if instance is None:
+ return ActionItem(**data)
+ for key, val in data.items():
+ setattr(instance, key, val)
+ return instance
+
+
+class PersonSerializer(serializers.ModelSerializer):
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age',)
+
+
+class NestedSerializer(serializers.Serializer):
+ info = serializers.Field()
+
+
+class ModelSerializerWithNestedSerializer(serializers.ModelSerializer):
+ nested = NestedSerializer(source='*')
+
+ class Meta:
+ model = Person
+
+
+class NestedSerializerWithRenamedField(serializers.Serializer):
+ renamed_info = serializers.Field(source='info')
+
+
+class ModelSerializerWithNestedSerializerWithRenamedField(serializers.ModelSerializer):
+ nested = NestedSerializerWithRenamedField(source='*')
+
+ class Meta:
+ model = Person
+
+
+class PersonSerializerInvalidReadOnly(serializers.ModelSerializer):
+ """
+ Testing for #652.
+ """
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+ read_only_fields = ('age', 'info')
+
+
+class AlbumsSerializer(serializers.ModelSerializer):
+
+ class Meta:
+ model = Album
+ fields = ['title'] # lists are also valid options
+
+
+class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = HasPositiveIntegerAsChoice
+ fields = ['some_integer']
+
+
+class BasicTests(TestCase):
+ def setUp(self):
+ self.comment = Comment(
+ 'tom@example.com',
+ 'Happy new year!',
+ datetime.datetime(2012, 1, 1)
+ )
+ self.actionitem = ActionItem(title='Some to do item',)
+ self.data = {
+ 'email': 'tom@example.com',
+ 'content': 'Happy new year!',
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'This wont change'
+ }
+ self.expected = {
+ 'email': 'tom@example.com',
+ 'content': 'Happy new year!',
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'And Merry Christmas!'
+ }
+ self.person_data = {'name': 'dwight', 'age': 35}
+ self.person = Person(**self.person_data)
+ self.person.save()
+
+ def test_empty(self):
+ serializer = CommentSerializer()
+ expected = {
+ 'email': '',
+ 'content': '',
+ 'created': None
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_retrieve(self):
+ serializer = CommentSerializer(self.comment)
+ self.assertEqual(serializer.data, self.expected)
+
+ def test_create(self):
+ serializer = CommentSerializer(data=self.data)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+ self.assertFalse(serializer.object is expected)
+ self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
+
+ def test_create_nested(self):
+ """Test a serializer with nested data."""
+ names = {'first': 'John', 'last': 'Doe', 'initials': 'jd'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ self.assertEqual(serializer.data['names'], names)
+
+ def test_create_partial_nested(self):
+ """Test a serializer with nested data which has missing fields."""
+ names = {'first': 'John'}
+ data = {'ssn': '1234567890', 'names': names}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ expected_names = {'first': 'John', 'last': '', 'initials': ''}
+ data['names'] = expected_names
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is expected_names)
+ self.assertEqual(serializer.data['names'], expected_names)
+
+ def test_null_nested(self):
+ """Test a serializer with a nonexistent nested field"""
+ data = {'ssn': '1234567890'}
+ serializer = PersonIdentifierSerializer(data=data)
+
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+ self.assertFalse(serializer.object is data)
+ expected = {'ssn': '1234567890', 'names': None}
+ self.assertEqual(serializer.data, expected)
+
+ def test_update(self):
+ serializer = CommentSerializer(self.comment, data=self.data)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+ self.assertTrue(serializer.object is expected)
+ self.assertEqual(serializer.data['sub_comment'], 'And Merry Christmas!')
+
+ def test_partial_update(self):
+ msg = 'Merry New Year!'
+ partial_data = {'content': msg}
+ serializer = CommentSerializer(self.comment, data=partial_data)
+ self.assertEqual(serializer.is_valid(), False)
+ serializer = CommentSerializer(self.comment, data=partial_data, partial=True)
+ expected = self.comment
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+ self.assertTrue(serializer.object is expected)
+ self.assertEqual(serializer.data['content'], msg)
+
+ def test_model_fields_as_expected(self):
+ """
+ Make sure that the fields returned are the same as defined
+ in the Meta data
+ """
+ serializer = PersonSerializer(self.person)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'age', 'info']))
+
+ def test_field_with_dictionary(self):
+ """
+ Make sure that dictionaries from fields are left intact
+ """
+ serializer = PersonSerializer(self.person)
+ expected = self.person_data
+ self.assertEqual(serializer.data['info'], expected)
+
+ def test_read_only_fields(self):
+ """
+ Attempting to update fields set as read_only should have no effect.
+ """
+ serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(serializer.errors, {})
+ # Assert age is unchanged (35)
+ self.assertEqual(instance.age, self.person_data['age'])
+
+ def test_invalid_read_only_fields(self):
+ """
+ Regression test for #652.
+ """
+ self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, [])
+
+ def test_serializer_data_is_cleared_on_save(self):
+ """
+ Check _data attribute is cleared on `save()`
+
+ Regression test for #1116
+ — id field is not populated if `data` is accessed prior to `save()`
+ """
+ serializer = ActionItemSerializer(self.actionitem)
+ self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.')
+ serializer.save()
+ self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.')
+
+ def test_fields_marked_as_not_required_are_excluded_from_validation(self):
+ """
+ Check that fields with `required=False` are included in list of exclusions.
+ """
+ serializer = ActionItemSerializerOptionalFields(self.actionitem)
+ exclusions = serializer.get_validation_exclusions()
+ self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded')
+
+
+class DictStyleSerializer(serializers.Serializer):
+ """
+ Note that we don't have any `restore_object` method, so the default
+ case of simply returning a dict will apply.
+ """
+ email = serializers.EmailField()
+
+
+class DictStyleSerializerTests(TestCase):
+ def test_dict_style_deserialize(self):
+ """
+ Ensure serializers can deserialize into a dict.
+ """
+ data = {'email': 'foo@example.com'}
+ serializer = DictStyleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, data)
+
+ def test_dict_style_serialize(self):
+ """
+ Ensure serializers can serialize dict objects.
+ """
+ data = {'email': 'foo@example.com'}
+ serializer = DictStyleSerializer(data)
+ self.assertEqual(serializer.data, data)
+
+
+class ValidationTests(TestCase):
+ def setUp(self):
+ self.comment = Comment(
+ 'tom@example.com',
+ 'Happy new year!',
+ datetime.datetime(2012, 1, 1)
+ )
+ self.data = {
+ 'email': 'tom@example.com',
+ 'content': 'x' * 1001,
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ self.actionitem = ActionItem(title='Some to do item',)
+
+ def test_create(self):
+ serializer = CommentSerializer(data=self.data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
+
+ def test_update(self):
+ serializer = CommentSerializer(self.comment, data=self.data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
+
+ def test_update_missing_field(self):
+ data = {
+ 'content': 'xxx',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = CommentSerializer(self.comment, data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'email': ['This field is required.']})
+
+ def test_missing_bool_with_default(self):
+ """Make sure that a boolean value with a 'False' value is not
+ mistaken for not having a default."""
+ data = {
+ 'title': 'Some action item',
+ #No 'done' value.
+ }
+ serializer = ActionItemSerializer(self.actionitem, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
+
+ def test_cross_field_validation(self):
+
+ class CommentSerializerWithCrossFieldValidator(CommentSerializer):
+
+ def validate(self, attrs):
+ if attrs["email"] not in attrs["content"]:
+ raise serializers.ValidationError("Email address not in content")
+ return attrs
+
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A comment from tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'A comment from foo@bar.com'
+
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'non_field_errors': ['Email address not in content']})
+
+ def test_null_is_true_fields(self):
+ """
+ Omitting a value for null-field should validate.
+ """
+ serializer = PersonSerializer(data={'name': 'marko'})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
+
+ def test_modelserializer_max_length_exceeded(self):
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_modelserializer_max_length_exceeded_with_custom_restore(self):
+ """
+ When overriding ModelSerializer.restore_object, validation tests should still apply.
+ Regression test for #623.
+
+ https://github.com/tomchristie/django-rest-framework/pull/623
+ """
+ data = {
+ 'title': 'x' * 201,
+ }
+ serializer = ActionItemSerializerCustomRestore(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
+
+ def test_default_modelfield_max_length_exceeded(self):
+ data = {
+ 'title': 'Testing "info" field...',
+ 'info': 'x' * 13,
+ }
+ serializer = ActionItemSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
+
+ def test_datetime_validation_failure(self):
+ """
+ Test DateTimeField validation errors on non-str values.
+ Regression test for #669.
+
+ https://github.com/tomchristie/django-rest-framework/issues/669
+ """
+ data = self.data
+ data['created'] = 0
+
+ serializer = CommentSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+
+ self.assertIn('created', serializer.errors)
+
+ def test_missing_model_field_exception_msg(self):
+ """
+ Assert that a meaningful exception message is outputted when the model
+ field is missing (e.g. when mistyping ``model``).
+ """
+ class BrokenModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ fields = ['some_field']
+
+ try:
+ BrokenModelSerializer()
+ except AssertionError as e:
+ self.assertEqual(e.args[0], "Serializer class 'BrokenModelSerializer' is missing 'model' Meta option")
+ except:
+ self.fail('Wrong exception type thrown.')
+
+ def test_writable_star_source_on_nested_serializer(self):
+ """
+ Assert that a nested serializer instantiated with source='*' correctly
+ expands the data into the outer serializer.
+ """
+ serializer = ModelSerializerWithNestedSerializer(data={
+ 'name': 'marko',
+ 'nested': {'info': 'hi'}},
+ )
+ self.assertEqual(serializer.is_valid(), True)
+
+ def test_writable_star_source_with_inner_source_fields(self):
+ """
+ Tests that a serializer with source="*" correctly expands the
+ it's fields into the outer serializer even if they have their
+ own 'source' parameters.
+ """
+
+ serializer = ModelSerializerWithNestedSerializerWithRenamedField(data={
+ 'name': 'marko',
+ 'nested': {'renamed_info': 'hi'}},
+ )
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.errors, {})
+
+
+class CustomValidationTests(TestCase):
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_email(self, attrs, source):
+ attrs[source]
+ return attrs
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ def test_field_validation(self):
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'content': ['Test not in value']})
+
+ def test_missing_data(self):
+ """
+ Make sure that validate_content isn't called if the field is missing
+ """
+ incomplete_data = {
+ 'email': 'tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'content': ['This field is required.']})
+
+ def test_wrong_data(self):
+ """
+ Make sure that validate_content isn't called if the field input is wrong
+ """
+ wrong_data = {
+ 'email': 'not an email',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']})
+
+ def test_partial_update(self):
+ """
+ Make sure that validate_email isn't called when partial=True and email
+ isn't found in data.
+ """
+ initial_data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=initial_data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.object
+
+ new_content = 'An *updated* test comment'
+ partial_data = {
+ 'content': new_content
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(instance=instance,
+ data=partial_data,
+ partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.object
+ self.assertEqual(instance.content, new_content)
+
+
+class PositiveIntegerAsChoiceTests(TestCase):
+ def test_positive_integer_in_json_is_correctly_parsed(self):
+ data = {'some_integer': 1}
+ serializer = PositiveIntegerAsChoiceSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+
+
+class ModelValidationTests(TestCase):
+ def test_validate_unique(self):
+ """
+ Just check if serializers.ModelSerializer handles unique checks via .full_clean()
+ """
+ serializer = AlbumsSerializer(data={'title': 'a'})
+ serializer.is_valid()
+ serializer.save()
+ second_serializer = AlbumsSerializer(data={'title': 'a'})
+ self.assertFalse(second_serializer.is_valid())
+ self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
+
+ def test_foreign_key_is_null_with_partial(self):
+ """
+ Test ModelSerializer validation with partial=True
+
+ Specifically test that a null foreign key does not pass validation
+ """
+ album = Album(title='test')
+ album.save()
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+
+ photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
+ self.assertTrue(photo_serializer.is_valid())
+ photo = photo_serializer.save()
+
+ # Updating only the album (foreign key)
+ photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True)
+ self.assertFalse(photo_serializer.is_valid())
+ self.assertTrue('album' in photo_serializer.errors)
+ self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required'])
+
+ def test_foreign_key_with_partial(self):
+ """
+ Test ModelSerializer validation with partial=True
+
+ Specifically test foreign key validation.
+ """
+
+ album = Album(title='test')
+ album.save()
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+
+ photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk})
+ self.assertTrue(photo_serializer.is_valid())
+ photo = photo_serializer.save()
+
+ # Updating only the album (foreign key)
+ photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True)
+ self.assertTrue(photo_serializer.is_valid())
+ self.assertTrue(photo_serializer.save())
+
+ # Updating only the description
+ photo_serializer = PhotoSerializer(instance=photo,
+ data={'description': 'new'},
+ partial=True)
+
+ self.assertTrue(photo_serializer.is_valid())
+ self.assertTrue(photo_serializer.save())
+
+
+class RegexValidationTest(TestCase):
+ def test_create_failed(self):
+ serializer = BookSerializer(data={'isbn': '1234567890'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': '12345678901234'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
+
+ serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
+
+ def test_create_success(self):
+ serializer = BookSerializer(data={'isbn': '1234567890123'})
+ self.assertTrue(serializer.is_valid())
+
+
+class MetadataTests(TestCase):
+ def test_empty(self):
+ serializer = CommentSerializer()
+ expected = {
+ 'email': serializers.CharField,
+ 'content': serializers.CharField,
+ 'created': serializers.DateTimeField
+ }
+ for field_name, field in expected.items():
+ self.assertTrue(isinstance(serializer.data.fields[field_name], field))
+
+
+class ManyToManyTests(TestCase):
+ def setUp(self):
+ class ManyToManySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyModel
+
+ self.serializer_class = ManyToManySerializer
+
+ # An anchor instance to use for the relationship
+ self.anchor = Anchor()
+ self.anchor.save()
+
+ # A model instance with a many to many relationship to the anchor
+ self.instance = ManyToManyModel()
+ self.instance.save()
+ self.instance.rel.add(self.anchor)
+
+ # A serialized representation of the model instance
+ self.data = {'id': 1, 'rel': [self.anchor.id]}
+
+ def test_retrieve(self):
+ """
+ Serialize an instance of a model with a ManyToMany relationship.
+ """
+ serializer = self.serializer_class(instance=self.instance)
+ expected = self.data
+ self.assertEqual(serializer.data, expected)
+
+ def test_create(self):
+ """
+ Create an instance of a model with a ManyToMany relationship.
+ """
+ data = {'rel': [self.anchor.id]}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
+
+ def test_update(self):
+ """
+ Update an instance of a model with a ManyToMany relationship.
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': [self.anchor.id, new_anchor.id]}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(list(instance.rel.all()), [self.anchor, new_anchor])
+
+ def test_create_empty_relationship(self):
+ """
+ Create an instance of a model with a ManyToMany relationship,
+ containing no items.
+ """
+ data = {'rel': []}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [])
+
+ def test_update_empty_relationship(self):
+ """
+ Update an instance of a model with a ManyToMany relationship,
+ containing no items.
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': []}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(list(instance.rel.all()), [])
+
+ def test_create_empty_relationship_flat_data(self):
+ """
+ Create an instance of a model with a ManyToMany relationship,
+ containing no items, using a representation that does not support
+ lists (eg form data).
+ """
+ data = MultiValueDict()
+ data.setlist('rel', [''])
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ManyToManyModel.objects.all()), 2)
+ self.assertEqual(instance.pk, 2)
+ self.assertEqual(list(instance.rel.all()), [])
+
+
+class ReadOnlyManyToManyTests(TestCase):
+ def setUp(self):
+ class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
+ rel = serializers.RelatedField(many=True, read_only=True)
+
+ class Meta:
+ model = ReadOnlyManyToManyModel
+
+ self.serializer_class = ReadOnlyManyToManySerializer
+
+ # An anchor instance to use for the relationship
+ self.anchor = Anchor()
+ self.anchor.save()
+
+ # A model instance with a many to many relationship to the anchor
+ self.instance = ReadOnlyManyToManyModel()
+ self.instance.save()
+ self.instance.rel.add(self.anchor)
+
+ # A serialized representation of the model instance
+ self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
+
+ def test_update(self):
+ """
+ Attempt to update an instance of a model with a ManyToMany
+ relationship. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': [self.anchor.id, new_anchor.id]}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
+
+ def test_update_without_relationship(self):
+ """
+ Attempt to update an instance of a model where many to ManyToMany
+ relationship is not supplied. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEqual(list(instance.rel.all()), [self.anchor])
+
+
+class DefaultValueTests(TestCase):
+ def setUp(self):
+ class DefaultValueSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DefaultValueModel
+
+ self.serializer_class = DefaultValueSerializer
+ self.objects = DefaultValueModel.objects
+
+ def test_create_using_default(self):
+ data = {}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'foobar')
+
+ def test_create_overriding_default(self):
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'overridden')
+
+ def test_partial_update_default(self):
+ """ Regression test for issue #532 """
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data, partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ data = {'extra': 'extra_value'}
+ serializer = self.serializer_class(instance=instance, data=data, partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+
+ self.assertEqual(instance.extra, 'extra_value')
+ self.assertEqual(instance.text, 'overridden')
+
+
+class CallableDefaultValueTests(TestCase):
+ def setUp(self):
+ class CallableDefaultValueSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = CallableDefaultValueModel
+
+ self.serializer_class = CallableDefaultValueSerializer
+ self.objects = CallableDefaultValueModel.objects
+
+ def test_create_using_default(self):
+ data = {}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'foobar')
+
+ def test_create_overriding_default(self):
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEqual(len(self.objects.all()), 1)
+ self.assertEqual(instance.pk, 1)
+ self.assertEqual(instance.text, 'overridden')
+
+
+class ManyRelatedTests(TestCase):
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'title': 'Test blog post',
+ 'comments': [
+ {'text': 'I hate this blog post'},
+ {'text': 'I love this blog post'}
+ ]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_callable_source(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ first_comment = BlogPostCommentSerializer(source='get_first_comment')
+
+ serializer = BlogPostSerializer(post)
+
+ expected = {
+ 'title': 'Test blog post',
+ 'first_comment': {'text': 'I love this blog post'}
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+class RelatedTraversalTest(TestCase):
+ def test_nested_traversal(self):
+ """
+ Source argument should support dotted.source notation.
+ """
+ user = Person.objects.create(name="django")
+ post = BlogPost.objects.create(title="Test blog post", writer=user)
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ("name", "age")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPostComment
+ fields = ("text", "post_owner")
+
+ text = serializers.CharField()
+ post_owner = PersonSerializer(source='blog_post.writer')
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+
+ expected = {
+ 'title': 'Test blog post',
+ 'comments': [{
+ 'text': 'I love this blog post',
+ 'post_owner': {
+ "name": "django",
+ "age": None
+ }
+ }]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_nested_traversal_with_none(self):
+ """
+ If a component of the dotted.source is None, return None for the field.
+ """
+ from tests.models import NullableForeignKeySource
+ instance = NullableForeignKeySource.objects.create(name='Source with null FK')
+
+ class NullableSourceSerializer(serializers.Serializer):
+ target_name = serializers.Field(source='target.name')
+
+ serializer = NullableSourceSerializer(instance=instance)
+
+ expected = {
+ 'target_name': None,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+class SerializerMethodFieldTests(TestCase):
+ def setUp(self):
+
+ class BoopSerializer(serializers.Serializer):
+ beep = serializers.SerializerMethodField('get_beep')
+ boop = serializers.Field()
+ boop_count = serializers.SerializerMethodField('get_boop_count')
+
+ def get_beep(self, obj):
+ return 'hello!'
+
+ def get_boop_count(self, obj):
+ return len(obj.boop)
+
+ self.serializer_class = BoopSerializer
+
+ def test_serializer_method_field(self):
+
+ class MyModel(object):
+ boop = ['a', 'b', 'c']
+
+ source_data = MyModel()
+
+ serializer = self.serializer_class(source_data)
+
+ expected = {
+ 'beep': 'hello!',
+ 'boop': ['a', 'b', 'c'],
+ 'boop_count': 3,
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+# Test for issue #324
+class BlankFieldTests(TestCase):
+ def setUp(self):
+
+ class BlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlankFieldModel
+
+ class BlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField(required=False)
+
+ class NotBlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+ class NotBlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField()
+
+ self.model_serializer_class = BlankFieldModelSerializer
+ self.serializer_class = BlankFieldSerializer
+ self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
+ self.not_blank_serializer_class = NotBlankFieldSerializer
+ self.data = {'title': ''}
+
+ def test_create_blank_field(self):
+ serializer = self.serializer_class(data=self.data)
+ self.assertEqual(serializer.is_valid(), True)
+
+ def test_create_model_blank_field(self):
+ serializer = self.model_serializer_class(data=self.data)
+ self.assertEqual(serializer.is_valid(), True)
+
+ def test_create_model_null_field(self):
+ serializer = self.model_serializer_class(data={'title': None})
+ self.assertEqual(serializer.is_valid(), True)
+
+ def test_create_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a non-model serializer
+ """
+ serializer = self.not_blank_serializer_class(data=self.data)
+ self.assertEqual(serializer.is_valid(), False)
+
+ def test_create_model_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a model serializer
+ """
+ serializer = self.not_blank_model_serializer_class(data=self.data)
+ self.assertEqual(serializer.is_valid(), False)
+
+ def test_create_model_empty_field(self):
+ serializer = self.model_serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+
+#test for issue #460
+class SerializerPickleTests(TestCase):
+ """
+ Test pickleability of the output of Serializers
+ """
+ def test_pickle_simple_model_serializer_data(self):
+ """
+ Test simple serializer
+ """
+ pickle.dumps(PersonSerializer(Person(name="Methusela", age=969)).data)
+
+ def test_pickle_inner_serializer(self):
+ """
+ Test pickling a serializer whose resulting .data (a SortedDictWithMetadata) will
+ have unpickleable meta data--in order to make sure metadata doesn't get pulled into the pickle.
+ See DictWithMetadata.__getstate__
+ """
+ class InnerPersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+ fields = ('name', 'age')
+ pickle.dumps(InnerPersonSerializer(Person(name="Noah", age=950)).data, 0)
+
+ def test_getstate_method_should_not_return_none(self):
+ """
+ Regression test for #645.
+ """
+ data = serializers.DictWithMetadata({1: 1})
+ self.assertEqual(data.__getstate__(), serializers.SortedDict({1: 1}))
+
+ def test_serializer_data_is_pickleable(self):
+ """
+ Another regression test for #645.
+ """
+ data = serializers.SortedDictWithMetadata({1: 1})
+ repr(pickle.loads(pickle.dumps(data, 0)))
+
+
+# test for issue #725
+class SeveralChoicesModel(models.Model):
+ color = models.CharField(
+ max_length=10,
+ choices=[('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')],
+ blank=False
+ )
+ drink = models.CharField(
+ max_length=10,
+ choices=[('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')],
+ blank=False,
+ default='beer'
+ )
+ os = models.CharField(
+ max_length=10,
+ choices=[('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')],
+ blank=True
+ )
+ music_genre = models.CharField(
+ max_length=10,
+ choices=[('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')],
+ blank=True,
+ default='metal'
+ )
+
+
+class SerializerChoiceFields(TestCase):
+
+ def setUp(self):
+ super(SerializerChoiceFields, self).setUp()
+
+ class SeveralChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SeveralChoicesModel
+ fields = ('color', 'drink', 'os', 'music_genre')
+
+ self.several_choices_serializer = SeveralChoicesSerializer
+
+ def test_choices_blank_false_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['color'].choices,
+ [('red', 'Red'), ('green', 'Green'), ('blue', 'Blue')]
+ )
+
+ def test_choices_blank_false_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['drink'].choices,
+ [('beer', 'Beer'), ('wine', 'Wine'), ('cider', 'Cider')]
+ )
+
+ def test_choices_blank_true_not_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['os'].choices,
+ BLANK_CHOICE_DASH + [('linux', 'Linux'), ('osx', 'OSX'), ('windows', 'Windows')]
+ )
+
+ def test_choices_blank_true_with_default(self):
+ serializer = self.several_choices_serializer()
+ self.assertEqual(
+ serializer.fields['music_genre'].choices,
+ BLANK_CHOICE_DASH + [('rock', 'Rock'), ('metal', 'Metal'), ('grunge', 'Grunge')]
+ )
+
+
+# Regression tests for #675
+class Ticket(models.Model):
+ assigned = models.ForeignKey(
+ Person, related_name='assigned_tickets')
+ reviewer = models.ForeignKey(
+ Person, blank=True, null=True, related_name='reviewed_tickets')
+
+
+class SerializerRelatedChoicesTest(TestCase):
+
+ def setUp(self):
+ super(SerializerRelatedChoicesTest, self).setUp()
+
+ class RelatedChoicesSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Ticket
+ fields = ('assigned', 'reviewer')
+
+ self.related_fields_serializer = RelatedChoicesSerializer
+
+ def test_empty_queryset_required(self):
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['assigned'].queryset.count(), 0)
+ self.assertEqual(
+ [x for x in serializer.fields['assigned'].widget.choices],
+ []
+ )
+
+ def test_empty_queryset_not_required(self):
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['reviewer'].queryset.count(), 0)
+ self.assertEqual(
+ [x for x in serializer.fields['reviewer'].widget.choices],
+ [('', '---------')]
+ )
+
+ def test_with_some_persons_required(self):
+ Person.objects.create(name="Lionel Messi")
+ Person.objects.create(name="Xavi Hernandez")
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['assigned'].queryset.count(), 2)
+ self.assertEqual(
+ [x for x in serializer.fields['assigned'].widget.choices],
+ [(1, 'Person object - 1'), (2, 'Person object - 2')]
+ )
+
+ def test_with_some_persons_not_required(self):
+ Person.objects.create(name="Lionel Messi")
+ Person.objects.create(name="Xavi Hernandez")
+ serializer = self.related_fields_serializer()
+ self.assertEqual(serializer.fields['reviewer'].queryset.count(), 2)
+ self.assertEqual(
+ [x for x in serializer.fields['reviewer'].widget.choices],
+ [('', '---------'), (1, 'Person object - 1'), (2, 'Person object - 2')]
+ )
+
+
+class DepthTest(TestCase):
+ def test_implicit_nesting(self):
+
+ writer = Person.objects.create(name="django", age=1)
+ post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPostComment
+ depth = 2
+
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
+
+ self.assertEqual(serializer.data, expected)
+
+ def test_explicit_nesting(self):
+ writer = Person.objects.create(name="django", age=1)
+ post = BlogPost.objects.create(title="Test blog post", writer=writer)
+ comment = BlogPostComment.objects.create(text="Test blog post comment", blog_post=post)
+
+ class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ writer = PersonSerializer()
+
+ class Meta:
+ model = BlogPost
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ blog_post = BlogPostSerializer()
+
+ class Meta:
+ model = BlogPostComment
+
+ serializer = BlogPostCommentSerializer(instance=comment)
+ expected = {'id': 1, 'text': 'Test blog post comment', 'blog_post': {'id': 1, 'title': 'Test blog post',
+ 'writer': {'id': 1, 'name': 'django', 'age': 1}}}
+
+ self.assertEqual(serializer.data, expected)
+
+
+class NestedSerializerContextTests(TestCase):
+
+ def test_nested_serializer_context(self):
+ """
+ Regression for #497
+
+ https://github.com/tomchristie/django-rest-framework/issues/497
+ """
+ class PhotoSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Photo
+ fields = ("description", "callable")
+
+ callable = serializers.SerializerMethodField('_callable')
+
+ def _callable(self, instance):
+ if not 'context_item' in self.context:
+ raise RuntimeError("context isn't getting passed into 2nd level nested serializer")
+ return "success"
+
+ class AlbumSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Album
+ fields = ("photo_set", "callable")
+
+ photo_set = PhotoSerializer(source="photo_set")
+ callable = serializers.SerializerMethodField("_callable")
+
+ def _callable(self, instance):
+ if not 'context_item' in self.context:
+ raise RuntimeError("context isn't getting passed into 1st level nested serializer")
+ return "success"
+
+ class AlbumCollection(object):
+ albums = None
+
+ class AlbumCollectionSerializer(serializers.Serializer):
+ albums = AlbumSerializer(source="albums")
+
+ album1 = Album.objects.create(title="album 1")
+ album2 = Album.objects.create(title="album 2")
+ Photo.objects.create(description="Bigfoot", album=album1)
+ Photo.objects.create(description="Unicorn", album=album1)
+ Photo.objects.create(description="Yeti", album=album2)
+ Photo.objects.create(description="Sasquatch", album=album2)
+ album_collection = AlbumCollection()
+ album_collection.albums = [album1, album2]
+
+ # This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers
+ AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data
+
+
+class DeserializeListTestCase(TestCase):
+
+ def setUp(self):
+ self.data = {
+ 'email': 'nobody@nowhere.com',
+ 'content': 'This is some test content',
+ 'created': datetime.datetime(2013, 3, 7),
+ }
+
+ def test_no_errors(self):
+ data = [self.data.copy() for x in range(0, 3)]
+ serializer = CommentSerializer(data=data, many=True)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, list))
+ self.assertTrue(
+ all((isinstance(item, Comment) for item in serializer.object))
+ )
+
+ def test_errors_return_as_list(self):
+ invalid_item = self.data.copy()
+ invalid_item['email'] = ''
+ data = [self.data.copy(), invalid_item, self.data.copy()]
+
+ serializer = CommentSerializer(data=data, many=True)
+ self.assertFalse(serializer.is_valid())
+ expected = [{}, {'email': ['This field is required.']}, {}]
+ self.assertEqual(serializer.errors, expected)
+
+
+# Test for issue 747
+
+class LazyStringModel(object):
+ def __init__(self, lazystring):
+ self.lazystring = lazystring
+
+
+class LazyStringSerializer(serializers.Serializer):
+ lazystring = serializers.Field()
+
+ def restore_object(self, attrs, instance=None):
+ if instance is not None:
+ instance.lazystring = attrs.get('lazystring', instance.lazystring)
+ return instance
+ return LazyStringModel(**attrs)
+
+
+class LazyStringsTestCase(TestCase):
+ def setUp(self):
+ self.model = LazyStringModel(lazystring=_('lazystring'))
+
+ def test_lazy_strings_are_translated(self):
+ serializer = LazyStringSerializer(self.model)
+ self.assertEqual(type(serializer.data['lazystring']),
+ type('lazystring'))
+
+
+# Test for issue #467
+
+class FieldLabelTest(TestCase):
+ def setUp(self):
+ self.serializer_class = BasicModelSerializer
+
+ def test_label_from_model(self):
+ """
+ Validates that label and help_text are correctly copied from the model class.
+ """
+ serializer = self.serializer_class()
+ text_field = serializer.fields['text']
+
+ self.assertEqual('Text comes here', text_field.label)
+ self.assertEqual('Text description.', text_field.help_text)
+
+ def test_field_ctor(self):
+ """
+ This is check that ctor supports both label and help_text.
+ """
+ self.assertEqual('Label', fields.Field(label='Label', help_text='Help').label)
+ self.assertEqual('Help', fields.CharField(label='Label', help_text='Help').help_text)
+ self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label)
+
+
+# Test for issue #961
+
+class ManyFieldHelpTextTest(TestCase):
+ def test_help_text_no_hold_down_control_msg(self):
+ """
+ Validate that help_text doesn't contain the 'Hold down "Control" ...'
+ message that Django appends to choice fields.
+ """
+ rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text)
+ self.assertEqual('Some help text.', rel_field.help_text)
+
+
+class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
+
+ def setUp(self):
+ class AMOAFModel(RESTFrameworkModel):
+ char_field = models.CharField(max_length=1024, blank=True)
+ comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
+ decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
+ email_field = models.EmailField(max_length=1024, blank=True)
+ file_field = models.FileField(max_length=1024, blank=True)
+ image_field = models.ImageField(max_length=1024, blank=True)
+ slug_field = models.SlugField(max_length=1024, blank=True)
+ url_field = models.URLField(max_length=1024, blank=True)
+
+ class AMOAFSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AMOAFModel
+
+ self.serializer_class = AMOAFSerializer
+ self.fields_attributes = {
+ 'char_field': [
+ ('max_length', 1024),
+ ],
+ 'comma_separated_integer_field': [
+ ('max_length', 1024),
+ ],
+ 'decimal_field': [
+ ('max_digits', 64),
+ ('decimal_places', 32),
+ ],
+ 'email_field': [
+ ('max_length', 1024),
+ ],
+ 'file_field': [
+ ('max_length', 1024),
+ ],
+ 'image_field': [
+ ('max_length', 1024),
+ ],
+ 'slug_field': [
+ ('max_length', 1024),
+ ],
+ 'url_field': [
+ ('max_length', 1024),
+ ],
+ }
+
+ def field_test(self, field):
+ serializer = self.serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+ for attribute in self.fields_attributes[field]:
+ self.assertEqual(
+ getattr(serializer.fields[field], attribute[0]),
+ attribute[1]
+ )
+
+ def test_char_field(self):
+ self.field_test('char_field')
+
+ def test_comma_separated_integer_field(self):
+ self.field_test('comma_separated_integer_field')
+
+ def test_decimal_field(self):
+ self.field_test('decimal_field')
+
+ def test_email_field(self):
+ self.field_test('email_field')
+
+ def test_file_field(self):
+ self.field_test('file_field')
+
+ def test_image_field(self):
+ self.field_test('image_field')
+
+ def test_slug_field(self):
+ self.field_test('slug_field')
+
+ def test_url_field(self):
+ self.field_test('url_field')
+
+
+class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
+
+ def setUp(self):
+ class DVOAFModel(RESTFrameworkModel):
+ positive_integer_field = models.PositiveIntegerField(blank=True)
+ positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
+ email_field = models.EmailField(blank=True)
+ file_field = models.FileField(blank=True)
+ image_field = models.ImageField(blank=True)
+ slug_field = models.SlugField(blank=True)
+ url_field = models.URLField(blank=True)
+
+ class DVOAFSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DVOAFModel
+
+ self.serializer_class = DVOAFSerializer
+ self.fields_attributes = {
+ 'positive_integer_field': [
+ ('min_value', 0),
+ ],
+ 'positive_small_integer_field': [
+ ('min_value', 0),
+ ],
+ 'email_field': [
+ ('max_length', 75),
+ ],
+ 'file_field': [
+ ('max_length', 100),
+ ],
+ 'image_field': [
+ ('max_length', 100),
+ ],
+ 'slug_field': [
+ ('max_length', 50),
+ ],
+ 'url_field': [
+ ('max_length', 200),
+ ],
+ }
+
+ def field_test(self, field):
+ serializer = self.serializer_class(data={})
+ self.assertEqual(serializer.is_valid(), True)
+
+ for attribute in self.fields_attributes[field]:
+ self.assertEqual(
+ getattr(serializer.fields[field], attribute[0]),
+ attribute[1]
+ )
+
+ def test_positive_integer_field(self):
+ self.field_test('positive_integer_field')
+
+ def test_positive_small_integer_field(self):
+ self.field_test('positive_small_integer_field')
+
+ def test_email_field(self):
+ self.field_test('email_field')
+
+ def test_file_field(self):
+ self.field_test('file_field')
+
+ def test_image_field(self):
+ self.field_test('image_field')
+
+ def test_slug_field(self):
+ self.field_test('slug_field')
+
+ def test_url_field(self):
+ self.field_test('url_field')
+
+
+class MetadataSerializer(serializers.Serializer):
+ field1 = serializers.CharField(3, required=True)
+ field2 = serializers.CharField(10, required=False)
+
+
+class MetadataSerializerTestCase(TestCase):
+ def setUp(self):
+ self.serializer = MetadataSerializer()
+
+ def test_serializer_metadata(self):
+ metadata = self.serializer.metadata()
+ expected = {
+ 'field1': {
+ 'required': True,
+ 'max_length': 3,
+ 'type': 'string',
+ 'read_only': False
+ },
+ 'field2': {
+ 'required': False,
+ 'max_length': 10,
+ 'type': 'string',
+ 'read_only': False
+ }
+ }
+ self.assertEqual(expected, metadata)
+
+
+### Regression test for #840
+
+class SimpleModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimpleModelSerializer(serializers.ModelSerializer):
+ text = serializers.CharField()
+ other = serializers.CharField()
+
+ class Meta:
+ model = SimpleModel
+
+ def validate_other(self, attrs, source):
+ del attrs['other']
+ return attrs
+
+
+class FieldValidationRemovingAttr(TestCase):
+ def test_removing_non_model_field_in_validation(self):
+ """
+ Removing an attr during field valiation should ensure that it is not
+ passed through when restoring the object.
+
+ This allows additional non-model fields to be supported.
+
+ Regression test for #840.
+ """
+ serializer = SimpleModelSerializer(data={'text': 'foo', 'other': 'bar'})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.object.text, 'foo')
+
+
+### Regression test for #878
+
+class SimpleTargetModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimplePKSourceModelSerializer(serializers.Serializer):
+ targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True)
+ text = serializers.CharField()
+
+
+class SimpleSlugSourceModelSerializer(serializers.Serializer):
+ targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk')
+ text = serializers.CharField()
+
+
+class SerializerSupportsManyRelationships(TestCase):
+ def setUp(self):
+ SimpleTargetModel.objects.create(text='foo')
+ SimpleTargetModel.objects.create(text='bar')
+
+ def test_serializer_supports_pk_many_relationships(self):
+ """
+ Regression test for #878.
+
+ Note that pk behavior has a different code path to usual cases,
+ for performance reasons.
+ """
+ serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+ def test_serializer_supports_slug_many_relationships(self):
+ """
+ Regression test for #878.
+ """
+ serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+
+class TransformMethodsSerializer(serializers.Serializer):
+ a = serializers.CharField()
+ b_renamed = serializers.CharField(source='b')
+
+ def transform_a(self, obj, value):
+ return value.lower()
+
+ def transform_b_renamed(self, obj, value):
+ if value is not None:
+ return 'and ' + value
+
+
+class TestSerializerTransformMethods(TestCase):
+ def setUp(self):
+ self.s = TransformMethodsSerializer()
+
+ def test_transform_methods(self):
+ self.assertEqual(
+ self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}),
+ {
+ 'a': 'green eggs',
+ 'b_renamed': 'and HAM',
+ }
+ )
+
+ def test_missing_fields(self):
+ self.assertEqual(
+ self.s.to_native({'a': 'GREEN EGGS'}),
+ {
+ 'a': 'green eggs',
+ 'b_renamed': None,
+ }
+ )
+
+
+class DefaultTrueBooleanModel(models.Model):
+ cat = models.BooleanField(default=True)
+ dog = models.BooleanField(default=False)
+
+
+class SerializerDefaultTrueBoolean(TestCase):
+
+ def setUp(self):
+ super(SerializerDefaultTrueBoolean, self).setUp()
+
+ class DefaultTrueBooleanSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DefaultTrueBooleanModel
+ fields = ('cat', 'dog')
+
+ self.default_true_boolean_serializer = DefaultTrueBooleanSerializer
+
+ def test_enabled_as_false(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': False,
+ 'dog': False})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], False)
+ self.assertEqual(serializer.data['dog'], False)
+
+ def test_enabled_as_true(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': True,
+ 'dog': True})
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], True)
+ self.assertEqual(serializer.data['dog'], True)
+
+ def test_enabled_partial(self):
+ serializer = self.default_true_boolean_serializer(data={'cat': False},
+ partial=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data['cat'], False)
+ self.assertEqual(serializer.data['dog'], False)
+
+
+class BoolenFieldTypeTest(TestCase):
+ '''
+ Ensure the various Boolean based model fields are rendered as the proper
+ field type
+
+ '''
+
+ def setUp(self):
+ '''
+ Setup an ActionItemSerializer for BooleanTesting
+ '''
+ data = {
+ 'title': 'b' * 201,
+ }
+ self.serializer = ActionItemSerializer(data=data)
+
+ def test_booleanfield_type(self):
+ '''
+ Test that BooleanField is infered from models.BooleanField
+ '''
+ bfield = self.serializer.get_fields()['done']
+ self.assertEqual(type(bfield), fields.BooleanField)
+
+ def test_nullbooleanfield_type(self):
+ '''
+ Test that BooleanField is infered from models.NullBooleanField
+
+ https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8
+ '''
+ bfield = self.serializer.get_fields()['started']
+ self.assertEqual(type(bfield), fields.BooleanField)
diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py
new file mode 100644
index 00000000..8b0ded1a
--- /dev/null
+++ b/tests/test_serializer_bulk_update.py
@@ -0,0 +1,278 @@
+"""
+Tests to cover bulk create and update using serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class BulkCreateSerializerTests(TestCase):
+ """
+ Creating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ self.BookSerializer = BookSerializer
+
+ def test_bulk_create_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_bulk_create_errors(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 'foo',
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_list_datatype(self):
+ """
+ Data containing list of incorrect data type should return errors.
+ """
+ data = ['foo', 'bar', 'baz']
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = [
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']}
+ ]
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_datatype(self):
+ """
+ Data containing a single incorrect data type should return errors.
+ """
+ data = 123
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_object(self):
+ """
+ Data containing only a single object, instead of a list of objects
+ should return errors.
+ """
+ data = {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class BulkUpdateSerializerTests(TestCase):
+ """
+ Updating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class Book(object):
+ """
+ A data type that can be persisted to a mock storage backend
+ with `.save()` and `.delete()`.
+ """
+ object_map = {}
+
+ def __init__(self, id, title, author):
+ self.id = id
+ self.title = title
+ self.author = author
+
+ def save(self):
+ Book.object_map[self.id] = self
+
+ def delete(self):
+ del Book.object_map[self.id]
+
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.id = attrs['id']
+ instance.title = attrs['title']
+ instance.author = attrs['author']
+ return instance
+ return Book(**attrs)
+
+ self.Book = Book
+ self.BookSerializer = BookSerializer
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ for item in data:
+ book = Book(item['id'], item['title'], item['author'])
+ book.save()
+
+ def books(self):
+ """
+ Return all the objects in the mock storage backend.
+ """
+ return self.Book.object_map.values()
+
+ def test_bulk_update_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 2,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_and_create(self):
+ """
+ Bulk update serialization may also include created items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_invalid_create(self):
+ """
+ Bulk update serialization without allow_add_remove may not create items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_bulk_update_error(self):
+ """
+ Incorrect bulk update serialization should return error data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 'foo',
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
diff --git a/tests/test_serializer_empty.py b/tests/test_serializer_empty.py
new file mode 100644
index 00000000..30cff361
--- /dev/null
+++ b/tests/test_serializer_empty.py
@@ -0,0 +1,15 @@
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class EmptySerializerTestCase(TestCase):
+ def test_empty_serializer(self):
+ class FooBarSerializer(serializers.Serializer):
+ foo = serializers.IntegerField()
+ bar = serializers.SerializerMethodField('get_bar')
+
+ def get_bar(self, obj):
+ return 'bar'
+
+ serializer = FooBarSerializer()
+ self.assertEquals(serializer.data, {'foo': 0})
diff --git a/tests/test_serializer_import.py b/tests/test_serializer_import.py
new file mode 100644
index 00000000..3b8ff4b3
--- /dev/null
+++ b/tests/test_serializer_import.py
@@ -0,0 +1,19 @@
+from django.test import TestCase
+
+from rest_framework import serializers
+from tests.accounts.serializers import AccountSerializer
+
+
+class ImportingModelSerializerTests(TestCase):
+ """
+ In some situations like, GH #1225, it is possible, especially in
+ testing, to import a serializer who's related models have not yet
+ been resolved by Django. `AccountSerializer` is an example of such
+ a serializer (imported at the top of this file).
+ """
+ def test_import_model_serializer(self):
+ """
+ The serializer at the top of this file should have been
+ imported successfully, and we should be able to instantiate it.
+ """
+ self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer)
diff --git a/tests/test_serializer_nested.py b/tests/test_serializer_nested.py
new file mode 100644
index 00000000..6d69ffbd
--- /dev/null
+++ b/tests/test_serializer_nested.py
@@ -0,0 +1,347 @@
+"""
+Tests to cover nested serializers.
+
+Doesn't cover model serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+from . import models
+
+
+class WritableNestedSerializerBasicTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ Basic tests that use serializers that simply restore to dicts.
+ """
+
+ def setUp(self):
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ expected_errors = {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_many_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data
+ when multiple entities are being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ ]
+ expected_errors = [
+ {},
+ {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+ ]
+
+ serializer = self.AlbumSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class WritableNestedSerializerObjectTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ These tests use serializers that restore to concrete objects.
+ """
+
+ def setUp(self):
+ # Couple of concrete objects that we're going to deserialize into
+ class Track(object):
+ def __init__(self, order, title, duration):
+ self.order, self.title, self.duration = order, title, duration
+
+ def __eq__(self, other):
+ return (
+ self.order == other.order and
+ self.title == other.title and
+ self.duration == other.duration
+ )
+
+ class Album(object):
+ def __init__(self, album_name, artist, tracks):
+ self.album_name, self.artist, self.tracks = album_name, artist, tracks
+
+ def __eq__(self, other):
+ return (
+ self.album_name == other.album_name and
+ self.artist == other.artist and
+ self.tracks == other.tracks
+ )
+
+ # And their corresponding serializers
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ def restore_object(self, attrs, instance=None):
+ return Track(attrs['order'], attrs['title'], attrs['duration'])
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ def restore_object(self, attrs, instance=None):
+ return Album(attrs['album_name'], attrs['artist'], attrs['tracks'])
+
+ self.Album, self.Track = Album, Track
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return a restored object
+ that corresponds to the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ expected_object = self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
+
+ def test_many_nested_validation_success(self):
+ """
+ Correct nested serialization should return multiple restored objects
+ that corresponds to the input data when multiple objects are
+ being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ ]
+ expected_object = [
+ self.Album(
+ album_name='Russian Red',
+ artist='I Love Your Glasses',
+ tracks=[
+ self.Track(order=1, title='Cigarettes', duration=121),
+ self.Track(order=2, title='No Past Land', duration=198),
+ self.Track(order=3, title='They Don\'t Believe', duration=191),
+ ]
+ ),
+ self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+ ]
+
+ serializer = self.AlbumSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
+
+
+class ForeignKeyNestedSerializerUpdateTests(TestCase):
+ def setUp(self):
+ class Artist(object):
+ def __init__(self, name):
+ self.name = name
+
+ def __eq__(self, other):
+ return self.name == other.name
+
+ class Album(object):
+ def __init__(self, name, artist):
+ self.name, self.artist = name, artist
+
+ def __eq__(self, other):
+ return self.name == other.name and self.artist == other.artist
+
+ class ArtistSerializer(serializers.Serializer):
+ name = serializers.CharField()
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.name = attrs['name']
+ else:
+ instance = Artist(attrs['name'])
+ return instance
+
+ class AlbumSerializer(serializers.Serializer):
+ name = serializers.CharField()
+ by = ArtistSerializer(source='artist')
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.name = attrs['name']
+ instance.artist = attrs['artist']
+ else:
+ instance = Album(attrs['name'], attrs['artist'])
+ return instance
+
+ self.Artist = Artist
+ self.Album = Album
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_create_via_foreign_key_with_source(self):
+ """
+ Check that we can both *create* and *update* into objects across
+ ForeignKeys that have a `source` specified.
+ Regression test for #1170
+ """
+ data = {
+ 'name': 'Discovery',
+ 'by': {'name': 'Daft Punk'},
+ }
+
+ expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery')
+
+ # create
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+
+ # update
+ original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters')
+ serializer = self.AlbumSerializer(instance=original, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected)
+
+
+class NestedModelSerializerUpdateTests(TestCase):
+ def test_second_nested_level(self):
+ john = models.Person.objects.create(name="john")
+
+ post = john.blogpost_set.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostCommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = models.BlogPostComment
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
+ class Meta:
+ model = models.BlogPost
+ fields = ('id', 'title', 'comments')
+
+ class PersonSerializer(serializers.ModelSerializer):
+ posts = BlogPostSerializer(many=True, source='blogpost_set')
+ class Meta:
+ model = models.Person
+ fields = ('id', 'name', 'age', 'posts')
+
+ serialize = PersonSerializer(instance=john)
+ deserialize = PersonSerializer(data=serialize.data, instance=john)
+ self.assertTrue(deserialize.is_valid())
+
+ result = deserialize.object
+ result.save()
+ self.assertEqual(result.id, john.id)
diff --git a/tests/test_serializers.py b/tests/test_serializers.py
new file mode 100644
index 00000000..67547783
--- /dev/null
+++ b/tests/test_serializers.py
@@ -0,0 +1,28 @@
+from django.db import models
+from django.test import TestCase
+
+from rest_framework.serializers import _resolve_model
+from tests.models import BasicModel
+
+
+class ResolveModelTests(TestCase):
+ """
+ `_resolve_model` should return a Django model class given the
+ provided argument is a Django model class itself, or a properly
+ formatted string representation of one.
+ """
+ def test_resolve_django_model(self):
+ resolved_model = _resolve_model(BasicModel)
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_string_representation(self):
+ resolved_model = _resolve_model('tests.BasicModel')
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_non_django_model(self):
+ with self.assertRaises(ValueError):
+ _resolve_model(TestCase)
+
+ def test_resolve_improper_string_representation(self):
+ with self.assertRaises(ValueError):
+ _resolve_model('BasicModel')
diff --git a/tests/test_settings.py b/tests/test_settings.py
new file mode 100644
index 00000000..e29fc34a
--- /dev/null
+++ b/tests/test_settings.py
@@ -0,0 +1,22 @@
+"""Tests for the settings module"""
+from __future__ import unicode_literals
+from django.test import TestCase
+
+from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS
+
+
+class TestSettings(TestCase):
+ """Tests relating to the api settings"""
+
+ def test_non_import_errors(self):
+ """Make sure other errors aren't suppressed."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ValueError):
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+
+ def test_import_error_message_maintained(self):
+ """Make sure real import errors are captured and raised sensibly."""
+ settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS)
+ with self.assertRaises(ImportError) as cm:
+ settings.DEFAULT_MODEL_SERIALIZER_CLASS
+ self.assertTrue('ImportError' in str(cm.exception))
diff --git a/tests/test_status.py b/tests/test_status.py
new file mode 100644
index 00000000..7b1bdae3
--- /dev/null
+++ b/tests/test_status.py
@@ -0,0 +1,33 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.status import (
+ is_informational, is_success, is_redirect, is_client_error, is_server_error
+)
+
+
+class TestStatus(TestCase):
+ def test_status_categories(self):
+ self.assertFalse(is_informational(99))
+ self.assertTrue(is_informational(100))
+ self.assertTrue(is_informational(199))
+ self.assertFalse(is_informational(200))
+
+ self.assertFalse(is_success(199))
+ self.assertTrue(is_success(200))
+ self.assertTrue(is_success(299))
+ self.assertFalse(is_success(300))
+
+ self.assertFalse(is_redirect(299))
+ self.assertTrue(is_redirect(300))
+ self.assertTrue(is_redirect(399))
+ self.assertFalse(is_redirect(400))
+
+ self.assertFalse(is_client_error(399))
+ self.assertTrue(is_client_error(400))
+ self.assertTrue(is_client_error(499))
+ self.assertFalse(is_client_error(500))
+
+ self.assertFalse(is_server_error(499))
+ self.assertTrue(is_server_error(500))
+ self.assertTrue(is_server_error(599))
+ self.assertFalse(is_server_error(600))
\ No newline at end of file
diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py
new file mode 100644
index 00000000..d4da0c23
--- /dev/null
+++ b/tests/test_templatetags.py
@@ -0,0 +1,51 @@
+# encoding: utf-8
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links
+
+factory = APIRequestFactory()
+
+
+class TemplateTagTests(TestCase):
+
+ def test_add_query_param_with_non_latin_charactor(self):
+ # Ensure we don't double-escape non-latin characters
+ # that are present in the querystring.
+ # See #1314.
+ request = factory.get("/", {'q': '查询'})
+ json_url = add_query_param(request, "format", "json")
+ self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url)
+ self.assertIn("format=json", json_url)
+
+
+class Issue1386Tests(TestCase):
+ """
+ Covers #1386
+ """
+
+ def test_issue_1386(self):
+ """
+ Test function urlize_quoted_links with different args
+ """
+ correct_urls = [
+ "asdf.com",
+ "asdf.net",
+ "www.as_df.org",
+ "as.d8f.ghj8.gov",
+ ]
+ for i in correct_urls:
+ res = urlize_quoted_links(i)
+ self.assertNotEqual(res, i)
+ self.assertIn(i, res)
+
+ incorrect_urls = [
+ "mailto://asdf@fdf.com",
+ "asdf.netnet",
+ ]
+ for i in incorrect_urls:
+ res = urlize_quoted_links(i)
+ self.assertEqual(i, res)
+
+ # example from issue #1386, this shouldn't raise an exception
+ _ = urlize_quoted_links("asdf:[/p]zxcv.com")
diff --git a/tests/test_testing.py b/tests/test_testing.py
new file mode 100644
index 00000000..8c6086a2
--- /dev/null
+++ b/tests/test_testing.py
@@ -0,0 +1,154 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
+from io import BytesIO
+
+from django.contrib.auth.models import User
+from django.test import TestCase
+from rest_framework.compat import patterns, url
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
+
+
+@api_view(['GET', 'POST'])
+def view(request):
+ return Response({
+ 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
+ 'user': request.user.username
+ })
+
+
+@api_view(['GET', 'POST'])
+def session_view(request):
+ active_session = request.session.get('active_session', False)
+ request.session['active_session'] = True
+ return Response({
+ 'active_session': active_session
+ })
+
+
+urlpatterns = patterns('',
+ url(r'^view/$', view),
+ url(r'^session-view/$', session_view),
+)
+
+
+class TestAPITestClient(TestCase):
+ urls = 'tests.test_testing'
+
+ def setUp(self):
+ self.client = APIClient()
+
+ def test_credentials(self):
+ """
+ Setting `.credentials()` adds the required headers to each request.
+ """
+ self.client.credentials(HTTP_AUTHORIZATION='example')
+ for _ in range(0, 3):
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['auth'], 'example')
+
+ def test_force_authenticate(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['user'], 'example')
+
+ def test_force_authenticate_with_sessions(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+
+ # First request does not yet have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ # Subsequant requests have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], True)
+
+ # Force authenticating as `None` should also logout the user session.
+ self.client.force_authenticate(None)
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ User.objects.create_user('example', 'example@example.com', 'password')
+ self.client.login(username='example', password='password')
+ response = self.client.post('/view/')
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ client = APIClient(enforce_csrf_checks=True)
+ User.objects.create_user('example', 'example@example.com', 'password')
+ client.login(username='example', password='password')
+ response = client.post('/view/')
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+
+class TestAPIRequestFactory(TestCase):
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory()
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory(enforce_csrf_checks=True)
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+ def test_invalid_format(self):
+ """
+ Attempting to use a format that is not configured will raise an
+ assertion error.
+ """
+ factory = APIRequestFactory()
+ self.assertRaises(AssertionError, factory.post,
+ path='/view/', data={'example': 1}, format='xml'
+ )
+
+ def test_force_authenticate(self):
+ """
+ Setting `force_authenticate()` forcibly authenticates the request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ factory = APIRequestFactory()
+ request = factory.get('/view')
+ force_authenticate(request, user=user)
+ response = view(request)
+ self.assertEqual(response.data['user'], 'example')
+
+ def test_upload_file(self):
+ # This is a 1x1 black png
+ simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82')
+ simple_png.name = 'test.png'
+ factory = APIRequestFactory()
+ factory.post('/', data={'image': simple_png})
diff --git a/tests/test_throttling.py b/tests/test_throttling.py
new file mode 100644
index 00000000..41bff692
--- /dev/null
+++ b/tests/test_throttling.py
@@ -0,0 +1,277 @@
+"""
+Tests for the throttling implementations in the permissions module.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.contrib.auth.models import User
+from django.core.cache import cache
+from rest_framework.test import APIRequestFactory
+from rest_framework.views import APIView
+from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
+from rest_framework.response import Response
+
+
+class User3SecRateThrottle(UserRateThrottle):
+ rate = '3/sec'
+ scope = 'seconds'
+
+
+class User3MinRateThrottle(UserRateThrottle):
+ rate = '3/min'
+ scope = 'minutes'
+
+
+class NonTimeThrottle(BaseThrottle):
+ def allow_request(self, request, view):
+ if not hasattr(self.__class__, 'called'):
+ self.__class__.called = True
+ return True
+ return False
+
+
+class MockView(APIView):
+ throttle_classes = (User3SecRateThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class MockView_MinuteThrottling(APIView):
+ throttle_classes = (User3MinRateThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class MockView_NonTimeThrottling(APIView):
+ throttle_classes = (NonTimeThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class ThrottlingTests(TestCase):
+ def setUp(self):
+ """
+ Reset the cache so that no throttles will be active
+ """
+ cache.clear()
+ self.factory = APIRequestFactory()
+
+ def test_requests_are_throttled(self):
+ """
+ Ensure request rate is limited
+ """
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(429, response.status_code)
+
+ def set_throttle_timer(self, view, value):
+ """
+ Explicitly set the timer, overriding time.time()
+ """
+ view.throttle_classes[0].timer = lambda self: value
+
+ def test_request_throttling_expires(self):
+ """
+ Ensure request rate is limited for a limited duration only
+ """
+ self.set_throttle_timer(MockView, 0)
+
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(429, response.status_code)
+
+ # Advance the timer by one second
+ self.set_throttle_timer(MockView, 1)
+
+ response = MockView.as_view()(request)
+ self.assertEqual(200, response.status_code)
+
+ def ensure_is_throttled(self, view, expect):
+ request = self.factory.get('/')
+ request.user = User.objects.create(username='a')
+ for dummy in range(3):
+ view.as_view()(request)
+ request.user = User.objects.create(username='b')
+ response = view.as_view()(request)
+ self.assertEqual(expect, response.status_code)
+
+ def test_request_throttling_is_per_user(self):
+ """
+ Ensure request rate is only limited per user, not globally for
+ PerUserThrottles
+ """
+ self.ensure_is_throttled(MockView, 200)
+
+ def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
+ """
+ Ensure the response returns an X-Throttle field with status and next attributes
+ set properly.
+ """
+ request = self.factory.get('/')
+ for timer, expect in expected_headers:
+ self.set_throttle_timer(view, timer)
+ response = view.as_view()(request)
+ if expect is not None:
+ self.assertEqual(response['X-Throttle-Wait-Seconds'], expect)
+ else:
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+ def test_seconds_fields(self):
+ """
+ Ensure for second based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView,
+ ((0, None),
+ (0, None),
+ (0, None),
+ (0, '1')
+ ))
+
+ def test_minutes_fields(self):
+ """
+ Ensure for minute based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, None),
+ (0, None),
+ (0, None),
+ (0, '60')
+ ))
+
+ def test_next_rate_remains_constant_if_followed(self):
+ """
+ If a client follows the recommended next request rate,
+ the throttling rate should stay constant.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, None),
+ (20, None),
+ (40, None),
+ (60, None),
+ (80, None)
+ ))
+
+ def test_non_time_throttle(self):
+ """
+ Ensure for second based throttles.
+ """
+ request = self.factory.get('/')
+
+ self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+ self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('X-Throttle-Wait-Seconds' in response)
+
+
+class ScopedRateThrottleTests(TestCase):
+ """
+ Tests for ScopedRateThrottle.
+ """
+
+ def setUp(self):
+ class XYScopedRateThrottle(ScopedRateThrottle):
+ TIMER_SECONDS = 0
+ THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
+ timer = lambda self: self.TIMER_SECONDS
+
+ class XView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'x'
+
+ def get(self, request):
+ return Response('x')
+
+ class YView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'y'
+
+ def get(self, request):
+ return Response('y')
+
+ class UnscopedView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+
+ def get(self, request):
+ return Response('y')
+
+ self.throttle_class = XYScopedRateThrottle
+ self.factory = APIRequestFactory()
+ self.x_view = XView.as_view()
+ self.y_view = YView.as_view()
+ self.unscoped_view = UnscopedView.as_view()
+
+ def increment_timer(self, seconds=1):
+ self.throttle_class.TIMER_SECONDS += seconds
+
+ def test_scoped_rate_throttle(self):
+ request = self.factory.get('/')
+
+ # Should be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Ensure throttles properly reset by advancing the rest of the minute
+ self.increment_timer(55)
+
+ # Should still be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should still be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ def test_unscoped_view_not_throttled(self):
+ request = self.factory.get('/')
+
+ for idx in range(10):
+ self.increment_timer()
+ response = self.unscoped_view(request)
+ self.assertEqual(200, response.status_code)
diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py
new file mode 100644
index 00000000..8132ec4c
--- /dev/null
+++ b/tests/test_urlpatterns.py
@@ -0,0 +1,76 @@
+from __future__ import unicode_literals
+from collections import namedtuple
+from django.core import urlresolvers
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+from rest_framework.compat import patterns, url, include
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+# A container class for test paths for the test case
+URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+
+
+def dummy_view(request, *args, **kwargs):
+ pass
+
+
+class FormatSuffixTests(TestCase):
+ """
+ Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
+ """
+ def _resolve_urlpatterns(self, urlpatterns, test_paths):
+ factory = APIRequestFactory()
+ try:
+ urlpatterns = format_suffix_patterns(urlpatterns)
+ except Exception:
+ self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
+ resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
+ for test_path in test_paths:
+ request = factory.get(test_path.path)
+ try:
+ callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ except Exception:
+ self.fail("Failed to resolve URL: %s" % request.path_info)
+ self.assertEqual(callback_args, test_path.args)
+ self.assertEqual(callback_kwargs, test_path.kwargs)
+
+ def test_format_suffix(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {}),
+ URLTestPath('/test.api', (), {'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_default_args(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view, {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {'foo': 'bar', }),
+ URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_included_urls(self):
+ nested_patterns = patterns(
+ '',
+ url(r'^path$', dummy_view)
+ )
+ urlpatterns = patterns(
+ '',
+ url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test/path', (), {'foo': 'bar', }),
+ URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/tests/test_validation.py b/tests/test_validation.py
new file mode 100644
index 00000000..124c874d
--- /dev/null
+++ b/tests/test_validation.py
@@ -0,0 +1,104 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import generics, serializers, status
+from rest_framework.test import APIRequestFactory
+
+factory = APIRequestFactory()
+
+
+# Regression for #666
+
+class ValidationModel(models.Model):
+ blank_validated_field = models.CharField(max_length=255)
+
+
+class ValidationModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ValidationModel
+ fields = ('blank_validated_field',)
+ read_only_fields = ('blank_validated_field',)
+
+
+class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
+ model = ValidationModel
+ serializer_class = ValidationModelSerializer
+
+
+class TestPreSaveValidationExclusions(TestCase):
+ def test_pre_save_validation_exclusions(self):
+ """
+ Somewhat weird test case to ensure that we don't perform model
+ validation on read only fields.
+ """
+ obj = ValidationModel.objects.create(blank_validated_field='')
+ request = factory.put('/', {}, format='json')
+ view = UpdateValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+# Regression for #653
+
+class ShouldValidateModel(models.Model):
+ should_validate_field = models.CharField(max_length=255)
+
+
+class ShouldValidateModelSerializer(serializers.ModelSerializer):
+ renamed = serializers.CharField(source='should_validate_field', required=False)
+
+ def validate_renamed(self, attrs, source):
+ value = attrs[source]
+ if len(value) < 3:
+ raise serializers.ValidationError('Minimum 3 characters.')
+ return attrs
+
+ class Meta:
+ model = ShouldValidateModel
+ fields = ('renamed',)
+
+
+class TestPreSaveValidationExclusionsSerializer(TestCase):
+ def test_renamed_fields_are_model_validated(self):
+ """
+ Ensure fields with 'source' applied do get still get model validation.
+ """
+ # We've set `required=False` on the serializer, but the model
+ # does not have `blank=True`, so this serializer should not validate.
+ serializer = ShouldValidateModelSerializer(data={'renamed': ''})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertIn('renamed', serializer.errors)
+ self.assertNotIn('should_validate_field', serializer.errors)
+
+
+class TestCustomValidationMethods(TestCase):
+ def test_custom_validation_method_is_executed(self):
+ serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'})
+ self.assertFalse(serializer.is_valid())
+ self.assertIn('renamed', serializer.errors)
+
+ def test_custom_validation_method_passing(self):
+ serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'})
+ self.assertTrue(serializer.is_valid())
+
+
+class ValidationSerializer(serializers.Serializer):
+ foo = serializers.CharField()
+
+ def validate_foo(self, attrs, source):
+ raise serializers.ValidationError("foo invalid")
+
+ def validate(self, attrs):
+ raise serializers.ValidationError("serializer invalid")
+
+
+class TestAvoidValidation(TestCase):
+ """
+ If serializer was initialized with invalid data (None or non dict-like), it
+ should avoid validation layer (validate_ and validate methods)
+ """
+ def test_serializer_errors_has_only_invalid_data_error(self):
+ serializer = ValidationSerializer(data='invalid data')
+ self.assertFalse(serializer.is_valid())
+ self.assertDictEqual(serializer.errors,
+ {'non_field_errors': ['Invalid data']})
diff --git a/tests/test_views.py b/tests/test_views.py
new file mode 100644
index 00000000..65c7e50e
--- /dev/null
+++ b/tests/test_views.py
@@ -0,0 +1,142 @@
+from __future__ import unicode_literals
+
+import copy
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from rest_framework.views import APIView
+
+factory = APIRequestFactory()
+
+
+class BasicView(APIView):
+ def get(self, request, *args, **kwargs):
+ return Response({'method': 'GET'})
+
+ def post(self, request, *args, **kwargs):
+ return Response({'method': 'POST', 'data': request.DATA})
+
+
+@api_view(['GET', 'POST', 'PUT', 'PATCH'])
+def basic_view(request):
+ if request.method == 'GET':
+ return {'method': 'GET'}
+ elif request.method == 'POST':
+ return {'method': 'POST', 'data': request.DATA}
+ elif request.method == 'PUT':
+ return {'method': 'PUT', 'data': request.DATA}
+ elif request.method == 'PATCH':
+ return {'method': 'PATCH', 'data': request.DATA}
+
+
+class ErrorView(APIView):
+ def get(self, request, *args, **kwargs):
+ raise Exception
+
+
+@api_view(['GET'])
+def error_view(request):
+ raise Exception
+
+
+def sanitise_json_error(error_dict):
+ """
+ Exact contents of JSON error messages depend on the installed version
+ of json.
+ """
+ ret = copy.copy(error_dict)
+ chop = len('JSON parse error - No JSON object could be decoded')
+ ret['detail'] = ret['detail'][:chop]
+ return ret
+
+
+class ClassBasedViewIntegrationTests(TestCase):
+ def setUp(self):
+ self.view = BasicView.as_view()
+
+ def test_400_parse_error(self):
+ request = factory.post('/', 'f00bar', content_type='application/json')
+ response = self.view(request)
+ expected = {
+ 'detail': 'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+ def test_400_parse_error_tunneled_content(self):
+ content = 'f00bar'
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = factory.post('/', form_data)
+ response = self.view(request)
+ expected = {
+ 'detail': 'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+
+class FunctionBasedViewIntegrationTests(TestCase):
+ def setUp(self):
+ self.view = basic_view
+
+ def test_400_parse_error(self):
+ request = factory.post('/', 'f00bar', content_type='application/json')
+ response = self.view(request)
+ expected = {
+ 'detail': 'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+ def test_400_parse_error_tunneled_content(self):
+ content = 'f00bar'
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = factory.post('/', form_data)
+ response = self.view(request)
+ expected = {
+ 'detail': 'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+
+class TestCustomExceptionHandler(TestCase):
+ def setUp(self):
+ self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
+
+ def exception_handler(exc):
+ return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
+
+ api_settings.EXCEPTION_HANDLER = exception_handler
+
+ def tearDown(self):
+ api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
+
+ def test_class_based_view_exception_handler(self):
+ view = ErrorView.as_view()
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
+
+ def test_function_based_view_exception_handler(self):
+ view = error_view
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
diff --git a/tests/test_write_only_fields.py b/tests/test_write_only_fields.py
new file mode 100644
index 00000000..aabb18d6
--- /dev/null
+++ b/tests/test_write_only_fields.py
@@ -0,0 +1,42 @@
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class ExampleModel(models.Model):
+ email = models.EmailField(max_length=100)
+ password = models.CharField(max_length=100)
+
+
+class WriteOnlyFieldTests(TestCase):
+ def test_write_only_fields(self):
+ class ExampleSerializer(serializers.Serializer):
+ email = serializers.EmailField()
+ password = serializers.CharField(write_only=True)
+
+ data = {
+ 'email': 'foo@example.com',
+ 'password': '123'
+ }
+ serializer = ExampleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.object, data)
+ self.assertEquals(serializer.data, {'email': 'foo@example.com'})
+
+ def test_write_only_fields_meta(self):
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ExampleModel
+ fields = ('email', 'password')
+ write_only_fields = ('password',)
+
+ data = {
+ 'email': 'foo@example.com',
+ 'password': '123'
+ }
+ serializer = ExampleSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertTrue(isinstance(serializer.object, ExampleModel))
+ self.assertEquals(serializer.object.email, data['email'])
+ self.assertEquals(serializer.object.password, data['password'])
+ self.assertEquals(serializer.data, {'email': 'foo@example.com'})
diff --git a/tests/urls.py b/tests/urls.py
new file mode 100644
index 00000000..62cad339
--- /dev/null
+++ b/tests/urls.py
@@ -0,0 +1,6 @@
+"""
+Blank URLConf just to keep the test suite happy
+"""
+from rest_framework.compat import patterns
+
+urlpatterns = patterns('')
diff --git a/tests/users/__init__.py b/tests/users/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/users/models.py b/tests/users/models.py
new file mode 100644
index 00000000..128bac90
--- /dev/null
+++ b/tests/users/models.py
@@ -0,0 +1,6 @@
+from django.db import models
+
+
+class User(models.Model):
+ account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users')
+ active_record = models.ForeignKey('records.Record', blank=True, null=True)
diff --git a/tests/users/serializers.py b/tests/users/serializers.py
new file mode 100644
index 00000000..4893ddb3
--- /dev/null
+++ b/tests/users/serializers.py
@@ -0,0 +1,8 @@
+from rest_framework import serializers
+
+from tests.users.models import User
+
+
+class UserSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = User
diff --git a/tests/views.py b/tests/views.py
new file mode 100644
index 00000000..55935e92
--- /dev/null
+++ b/tests/views.py
@@ -0,0 +1,8 @@
+from rest_framework import generics
+from .models import NullableForeignKeySource
+from .serializers import NullableFKSourceSerializer
+
+
+class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = NullableForeignKeySource
+ model_serializer_class = NullableFKSourceSerializer
diff --git a/tox.ini b/tox.ini
index 77766d20..2fe39f12 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,19 +3,21 @@ downloadcache = {toxworkdir}/cache/
envlist = py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,py2.7-django1.4,py2.6-django1.4,py2.7-django1.3,py2.6-django1.3
[testenv]
-commands = {envpython} rest_framework/runtests/runtests.py
+commands = py.test -q
[testenv:py3.3-django1.6]
basepython = python3.3
deps = Django==1.6
django-filter==0.6a1
defusedxml==0.3
+ pytest-django==2.6
[testenv:py3.2-django1.6]
basepython = python3.2
deps = Django==1.6
django-filter==0.6a1
defusedxml==0.3
+ pytest-django==2.6
[testenv:py2.7-django1.6]
basepython = python2.7
@@ -26,6 +28,7 @@ deps = Django==1.6
oauth2==1.5.211
django-oauth2-provider==0.2.4
django-guardian==1.1.1
+ pytest-django==2.6
[testenv:py2.6-django1.6]
basepython = python2.6
@@ -36,18 +39,21 @@ deps = Django==1.6
oauth2==1.5.211
django-oauth2-provider==0.2.4
django-guardian==1.1.1
+ pytest-django==2.6
[testenv:py3.3-django1.5]
basepython = python3.3
deps = django==1.5.5
django-filter==0.6a1
defusedxml==0.3
+ pytest-django==2.6
[testenv:py3.2-django1.5]
basepython = python3.2
deps = django==1.5.5
django-filter==0.6a1
defusedxml==0.3
+ pytest-django==2.6
[testenv:py2.7-django1.5]
basepython = python2.7
@@ -58,6 +64,7 @@ deps = django==1.5.5
oauth2==1.5.211
django-oauth2-provider==0.2.3
django-guardian==1.1.1
+ pytest-django==2.6
[testenv:py2.6-django1.5]
basepython = python2.6
@@ -68,6 +75,7 @@ deps = django==1.5.5
oauth2==1.5.211
django-oauth2-provider==0.2.3
django-guardian==1.1.1
+ pytest-django==2.6
[testenv:py2.7-django1.4]
basepython = python2.7
@@ -78,6 +86,7 @@ deps = django==1.4.10
oauth2==1.5.211
django-oauth2-provider==0.2.3
django-guardian==1.1.1
+ pytest-django==2.6
[testenv:py2.6-django1.4]
basepython = python2.6
@@ -88,6 +97,7 @@ deps = django==1.4.10
oauth2==1.5.211
django-oauth2-provider==0.2.3
django-guardian==1.1.1
+ pytest-django==2.6
[testenv:py2.7-django1.3]
basepython = python2.7
@@ -98,6 +108,7 @@ deps = django==1.3.5
oauth2==1.5.211
django-oauth2-provider==0.2.3
django-guardian==1.1.1
+ pytest-django==2.6
[testenv:py2.6-django1.3]
basepython = python2.6
@@ -108,3 +119,4 @@ deps = django==1.3.5
oauth2==1.5.211
django-oauth2-provider==0.2.3
django-guardian==1.1.1
+ pytest-django==2.6
--
cgit v1.2.3
From 6c108c459d8cfeda46b8e045ef750c01dd0ffcaa Mon Sep 17 00:00:00 2001
From: Ian Foote
Date: Wed, 16 Apr 2014 12:32:04 +0100
Subject: Allow customising ChoiceField blank display value
---
rest_framework/fields.py | 8 ++++++--
rest_framework/tests/test_fields.py | 9 +++++++++
2 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 946a5954..d9521cd4 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -509,12 +509,16 @@ class ChoiceField(WritableField):
'the available choices.'),
}
- def __init__(self, choices=(), *args, **kwargs):
+ def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
self.choices = choices
if not self.required:
- self.choices = BLANK_CHOICE_DASH + self.choices
+ if blank_display_value is None:
+ blank_choice = BLANK_CHOICE_DASH
+ else:
+ blank_choice = [('', blank_display_value)]
+ self.choices = blank_choice + self.choices
def _get_choices(self):
return self._choices
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index e127feef..63dff718 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -706,6 +706,15 @@ class ChoiceFieldTests(TestCase):
f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES)
self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES)
+ def test_blank_choice_display(self):
+ blank = 'No Preference'
+ f = serializers.ChoiceField(
+ required=False,
+ choices=SAMPLE_CHOICES,
+ blank_display_value=blank,
+ )
+ self.assertEqual(f.choices, [('', blank)] + SAMPLE_CHOICES)
+
def test_invalid_choice_model(self):
s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'})
self.assertFalse(s.is_valid())
--
cgit v1.2.3
From f22ed49c648c6dc3e2cd3c1dfbda77c010189e28 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Thu, 17 Apr 2014 11:09:02 +0200
Subject: Upgraded to pytest-django 2.6.1
---
.travis.yml | 2 +-
conftest.py | 4 ----
tox.ini | 27 +++++++++++++++------------
3 files changed, 16 insertions(+), 17 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index 13dc3e28..4f4d0c30 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -16,7 +16,7 @@ env:
install:
- pip install $DJANGO
- pip install defusedxml==0.3 Pillow==2.3.0
- - pip install pytest-django==2.6
+ - pip install pytest-django==2.6.1
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.4; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi"
diff --git a/conftest.py b/conftest.py
index 7cfc77f2..b1691a88 100644
--- a/conftest.py
+++ b/conftest.py
@@ -79,7 +79,3 @@ def pytest_configure():
settings.INSTALLED_APPS += (
'guardian',
)
-
- # Force Django to load all models
- from django.db.models import get_models
- get_models()
diff --git a/tox.ini b/tox.ini
index 251a40b7..2f6f1612 100644
--- a/tox.ini
+++ b/tox.ini
@@ -11,6 +11,7 @@ deps = https://www.djangoproject.com/download/1.7b1/tarball/
django-filter==0.7
defusedxml==0.3
Pillow==2.3.0
+ pytest-django==2.6.1
[testenv:py3.2-django1.7]
basepython = python3.2
@@ -18,6 +19,7 @@ deps = https://www.djangoproject.com/download/1.7b1/tarball/
django-filter==0.7
defusedxml==0.3
Pillow==2.3.0
+ pytest-django==2.6.1
[testenv:py2.7-django1.7]
basepython = python2.7
@@ -29,6 +31,7 @@ deps = https://www.djangoproject.com/download/1.7b1/tarball/
django-oauth2-provider==0.2.4
django-guardian==1.1.1
Pillow==2.3.0
+ pytest-django==2.6.1
[testenv:py3.3-django1.6]
basepython = python3.3
@@ -36,7 +39,7 @@ deps = Django==1.6
django-filter==0.7
defusedxml==0.3
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py3.2-django1.6]
basepython = python3.2
@@ -44,7 +47,7 @@ deps = Django==1.6
django-filter==0.7
defusedxml==0.3
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py2.7-django1.6]
basepython = python2.7
@@ -56,7 +59,7 @@ deps = Django==1.6
django-oauth2-provider==0.2.4
django-guardian==1.1.1
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py2.6-django1.6]
basepython = python2.6
@@ -68,7 +71,7 @@ deps = Django==1.6
django-oauth2-provider==0.2.4
django-guardian==1.1.1
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py3.3-django1.5]
basepython = python3.3
@@ -76,7 +79,7 @@ deps = django==1.5.5
django-filter==0.7
defusedxml==0.3
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py3.2-django1.5]
basepython = python3.2
@@ -84,7 +87,7 @@ deps = django==1.5.5
django-filter==0.7
defusedxml==0.3
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py2.7-django1.5]
basepython = python2.7
@@ -96,7 +99,7 @@ deps = django==1.5.5
django-oauth2-provider==0.2.3
django-guardian==1.1.1
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py2.6-django1.5]
basepython = python2.6
@@ -108,7 +111,7 @@ deps = django==1.5.5
django-oauth2-provider==0.2.3
django-guardian==1.1.1
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py2.7-django1.4]
basepython = python2.7
@@ -120,7 +123,7 @@ deps = django==1.4.10
django-oauth2-provider==0.2.3
django-guardian==1.1.1
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py2.6-django1.4]
basepython = python2.6
@@ -132,7 +135,7 @@ deps = django==1.4.10
django-oauth2-provider==0.2.3
django-guardian==1.1.1
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py2.7-django1.3]
basepython = python2.7
@@ -144,7 +147,7 @@ deps = django==1.3.5
django-oauth2-provider==0.2.3
django-guardian==1.1.1
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
[testenv:py2.6-django1.3]
basepython = python2.6
@@ -156,4 +159,4 @@ deps = django==1.3.5
django-oauth2-provider==0.2.3
django-guardian==1.1.1
Pillow==2.3.0
- pytest-django==2.6
+ pytest-django==2.6.1
--
cgit v1.2.3
From c5f68fba0638a15fa3c802f1bafc664e890611dc Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Thu, 17 Apr 2014 14:30:33 +0200
Subject: Fixed the issue with django-filters / django 1.7 / pytest
---
conftest.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/conftest.py b/conftest.py
index b1691a88..7a49845f 100644
--- a/conftest.py
+++ b/conftest.py
@@ -79,3 +79,9 @@ def pytest_configure():
settings.INSTALLED_APPS += (
'guardian',
)
+
+ try:
+ import django
+ django.setup()
+ except AttributeError:
+ pass
--
cgit v1.2.3
From cd93cd195ef83a443e8fe7d745b2947d2636f4ad Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Wed, 30 Apr 2014 22:32:29 +0200
Subject: Use url functions from Django itself.
---
rest_framework/tests/test_authentication.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py
index 09203057..af292bf1 100644
--- a/rest_framework/tests/test_authentication.py
+++ b/rest_framework/tests/test_authentication.py
@@ -20,7 +20,7 @@ from rest_framework.authentication import (
OAuth2Authentication
)
from rest_framework.authtoken.models import Token
-from rest_framework.compat import patterns, url, include, six
+from rest_framework.compat import six
from rest_framework.compat import oauth2_provider, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider
from rest_framework.test import APIRequestFactory, APIClient
--
cgit v1.2.3
From 7475fceacc5bc94fde6212937685ef69ae79c751 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Thu, 1 May 2014 00:54:20 +0200
Subject: Added missing field for the tests.
---
rest_framework/tests/test_serializer.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index f0bb112d..31bd1082 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -30,6 +30,7 @@ if PIL is not None:
image_field = models.ImageField(upload_to='test', max_length=1024, blank=True)
slug_field = models.SlugField(max_length=1024, blank=True)
url_field = models.URLField(max_length=1024, blank=True)
+ nullable_char_field = models.CharField(max_length=1024, blank=True, null=True)
class DVOAFModel(RESTFrameworkModel):
positive_integer_field = models.PositiveIntegerField(blank=True)
--
cgit v1.2.3
From 38362bb43a19c287319ccfe0538ce5524f09c633 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Thu, 1 May 2014 01:24:48 +0200
Subject: Fixed new default for many
---
rest_framework/tests/test_genericrelations.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py
index 46a2d863..3a8f3c7f 100644
--- a/rest_framework/tests/test_genericrelations.py
+++ b/rest_framework/tests/test_genericrelations.py
@@ -84,7 +84,7 @@ class TestGenericRelations(TestCase):
exclude = ('content_type', 'object_id')
class BookmarkSerializer(serializers.ModelSerializer):
- tags = TagSerializer()
+ tags = TagSerializer(many=True)
class Meta:
model = Bookmark
--
cgit v1.2.3
From c9e6f31166ebccc5c3bf2f27e12a6d6c87f5cf22 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Thu, 1 May 2014 01:27:51 +0200
Subject: Fixed new default for many
---
rest_framework/tests/test_serializer.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 31bd1082..44ef8a95 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -661,7 +661,7 @@ class ModelValidationTests(TestCase):
second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],})
- third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}])
+ third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}], many=True)
self.assertFalse(third_serializer.is_valid())
self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
--
cgit v1.2.3
From eb89ed02f247d903db1cdd488d69b316323d9f60 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Thu, 1 May 2014 08:36:18 +0200
Subject: Added missing staticfiles app
---
conftest.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/conftest.py b/conftest.py
index 7a49845f..fa5184dd 100644
--- a/conftest.py
+++ b/conftest.py
@@ -27,6 +27,7 @@ def pytest_configure():
'django.contrib.sessions',
'django.contrib.sites',
'django.contrib.messages',
+ 'django.contrib.staticfiles',
'rest_framework',
'rest_framework.authtoken',
--
cgit v1.2.3
From e5441d845e34f1e1bb2b7464d31aa3df7b02d0fe Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Thu, 1 May 2014 08:41:37 +0200
Subject: Use urls functions from django instead of compat.
---
tests/urls.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/urls.py b/tests/urls.py
index 62cad339..41f527df 100644
--- a/tests/urls.py
+++ b/tests/urls.py
@@ -1,6 +1,6 @@
"""
Blank URLConf just to keep the test suite happy
"""
-from rest_framework.compat import patterns
+from django.conf.urls import patterns
urlpatterns = patterns('')
--
cgit v1.2.3
From 15c2c58b43a00ec29af99e0478b70eea57560fce Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Thu, 1 May 2014 08:43:49 +0200
Subject: Updated the release-notes.
---
docs/topics/release-notes.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md
index d6256b38..fd5c7029 100644
--- a/docs/topics/release-notes.md
+++ b/docs/topics/release-notes.md
@@ -40,6 +40,7 @@ You can determine your currently installed version using `pip freeze`:
### 2.4.0
+* Use py.test
* `@detail_route` and `@list_route` decorators replace `@action` and `@link`.
* `six` no longer bundled. For Django <= 1.4.1, install `six` package.
* Support customizable view name and description functions, using the `VIEW_NAME_FUNCTION` and `VIEW_DESCRIPTION_FUNCTION` settings.
--
cgit v1.2.3
From 4e33ff05d9aabee0a90bfb0ef8ce58a5d274b9a2 Mon Sep 17 00:00:00 2001
From: Lucian Mocanu
Date: Sun, 4 May 2014 00:12:08 +0200
Subject: Automatically set the field name as value for the HTML `id` attribute
on the rendered widget.
---
rest_framework/fields.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 8cdc5551..e6733849 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -154,7 +154,12 @@ class Field(object):
def widget_html(self):
if not self.widget:
return ''
- return self.widget.render(self._name, self._value)
+
+ attrs = {}
+ if 'id' not in self.widget.attrs:
+ attrs['id'] = self._name
+
+ return self.widget.render(self._name, self._value, attrs=attrs)
def label_tag(self):
return '' % (self._name, self.label)
--
cgit v1.2.3
From 708c7b3a816c3c2df7847695044ef852dc89e72c Mon Sep 17 00:00:00 2001
From: Lucian Mocanu
Date: Tue, 6 May 2014 14:17:51 +0200
Subject: Added test case to check if the proper attributes are set on html
widgets.
---
rest_framework/tests/test_fields.py | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py
index e127feef..03f79cf4 100644
--- a/rest_framework/tests/test_fields.py
+++ b/rest_framework/tests/test_fields.py
@@ -4,6 +4,7 @@ General serializer field tests.
from __future__ import unicode_literals
import datetime
+import re
from decimal import Decimal
from uuid import uuid4
from django.core import validators
@@ -103,6 +104,16 @@ class BasicFieldTests(TestCase):
keys = list(field.to_native(ret).keys())
self.assertEqual(keys, ['c', 'b', 'a', 'z'])
+ def test_widget_html_attributes(self):
+ """
+ Make sure widget_html() renders the correct attributes
+ """
+ r = re.compile('(\S+)=["\']?((?:.(?!["\']?\s+(?:\S+)=|[>"\']))+.)["\']?')
+ form = TimeFieldModelSerializer().data
+ attributes = r.findall(form.fields['clock'].widget_html())
+ self.assertIn(('name', 'clock'), attributes)
+ self.assertIn(('id', 'clock'), attributes)
+
class DateFieldTest(TestCase):
"""
--
cgit v1.2.3
From 8ecb778cd23d5d561f2e9f4a3561bb1647257a89 Mon Sep 17 00:00:00 2001
From: Corey Farwell
Date: Sun, 11 May 2014 20:29:01 -0700
Subject: Enable testing on Python 3.4
---
.travis.yml | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/.travis.yml b/.travis.yml
index bd6d2539..0c9b4455 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,6 +5,7 @@ python:
- "2.7"
- "3.2"
- "3.3"
+ - "3.4"
env:
- DJANGO="https://www.djangoproject.com/download/1.7b2/tarball/"
@@ -41,4 +42,7 @@ matrix:
env: DJANGO="django==1.4.11"
- python: "3.3"
env: DJANGO="django==1.3.7"
-
+ - python: "3.4"
+ env: DJANGO="django==1.4.11"
+ - python: "3.4"
+ env: DJANGO="django==1.3.7"
--
cgit v1.2.3
From 768f537dcbb5d4f7429a74556559047bfd6f3078 Mon Sep 17 00:00:00 2001
From: Giorgos Logiotatidis
Date: Thu, 15 May 2014 15:34:31 +0300
Subject: Typo fix.
---
docs/api-guide/serializers.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md
index 7ee060af..0044f070 100644
--- a/docs/api-guide/serializers.md
+++ b/docs/api-guide/serializers.md
@@ -464,7 +464,7 @@ For more specific requirements such as specifying a different lookup for each fi
model = Account
fields = ('url', 'account_name', 'users', 'created')
-## Overiding the URL field behavior
+## Overriding the URL field behavior
The name of the URL field defaults to 'url'. You can override this globally, by using the `URL_FIELD_NAME` setting.
@@ -478,7 +478,7 @@ You can also override this on a per-serializer basis by using the `url_field_nam
**Note**: The generic view implementations normally generate a `Location` header in response to successful `POST` requests. Serializers using `url_field_name` option will not have this header automatically included by the view. If you need to do so you will ned to also override the view's `get_success_headers()` method.
-You can also overide the URL field's view name and lookup field without overriding the field explicitly, by using the `view_name` and `lookup_field` options, like so:
+You can also override the URL field's view name and lookup field without overriding the field explicitly, by using the `view_name` and `lookup_field` options, like so:
class AccountSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
--
cgit v1.2.3
From e5556079fc2559916d62b766dc9776b03dc4256b Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Fri, 16 May 2014 00:50:16 +0200
Subject: Updated tox with Python 2.4
---
tox.ini | 28 +++++++++++++++++++++++++++-
1 file changed, 27 insertions(+), 1 deletion(-)
diff --git a/tox.ini b/tox.ini
index e2121005..35a108e5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,10 +1,22 @@
[tox]
downloadcache = {toxworkdir}/cache/
-envlist = py3.3-django1.7,py3.2-django1.7,py2.7-django1.7,py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,py2.7-django1.4,py2.6-django1.4,py2.7-django1.3,py2.6-django1.3
+envlist =
+ py3.4-django1.7,py3.3-django1.7,py3.2-django1.7,py2.7-django1.7,
+ py3.4-django1.6,py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6,
+ py3.4-django1.5,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,
+ py2.7-django1.4,py2.6-django1.4,
+ py2.7-django1.3,py2.6-django1.3
[testenv]
commands = {envpython} rest_framework/runtests/runtests.py
+[testenv:py3.4-django1.7]
+basepython = python3.4
+deps = https://www.djangoproject.com/download/1.7b2/tarball/
+ django-filter==0.7
+ defusedxml==0.3
+ Pillow==2.3.0
+
[testenv:py3.3-django1.7]
basepython = python3.3
deps = https://www.djangoproject.com/download/1.7b2/tarball/
@@ -30,6 +42,13 @@ deps = https://www.djangoproject.com/download/1.7b2/tarball/
django-guardian==1.1.1
Pillow==2.3.0
+[testenv:py3.4-django1.6]
+basepython = python3.3
+deps = Django==1.6.3
+ django-filter==0.7
+ defusedxml==0.3
+ Pillow==2.3.0
+
[testenv:py3.3-django1.6]
basepython = python3.3
deps = Django==1.6.3
@@ -66,6 +85,13 @@ deps = Django==1.6.3
django-guardian==1.1.1
Pillow==2.3.0
+[testenv:py3.4-django1.5]
+basepython = python3.3
+deps = django==1.5.6
+ django-filter==0.7
+ defusedxml==0.3
+ Pillow==2.3.0
+
[testenv:py3.3-django1.5]
basepython = python3.3
deps = django==1.5.6
--
cgit v1.2.3
From b370fb40b6bc0fd3f597fb8c2db59f0ca57a7ccd Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Fri, 16 May 2014 01:06:34 +0200
Subject: Typo in the Python version.
---
tox.ini | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tox.ini b/tox.ini
index 35a108e5..279f79cc 100644
--- a/tox.ini
+++ b/tox.ini
@@ -43,7 +43,7 @@ deps = https://www.djangoproject.com/download/1.7b2/tarball/
Pillow==2.3.0
[testenv:py3.4-django1.6]
-basepython = python3.3
+basepython = python3.4
deps = Django==1.6.3
django-filter==0.7
defusedxml==0.3
@@ -86,7 +86,7 @@ deps = Django==1.6.3
Pillow==2.3.0
[testenv:py3.4-django1.5]
-basepython = python3.3
+basepython = python3.4
deps = django==1.5.6
django-filter==0.7
defusedxml==0.3
--
cgit v1.2.3
From a704d5a206238c65765c1f02eb053e461675dda2 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Fri, 16 May 2014 01:20:40 +0200
Subject: Fixed tests for python 3.4
---
rest_framework/tests/test_views.py | 16 +++++++++++-----
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py
index 65c7e50e..77b113ee 100644
--- a/rest_framework/tests/test_views.py
+++ b/rest_framework/tests/test_views.py
@@ -1,5 +1,6 @@
from __future__ import unicode_literals
+import sys
import copy
from django.test import TestCase
from rest_framework import status
@@ -11,6 +12,11 @@ from rest_framework.views import APIView
factory = APIRequestFactory()
+if sys.version_info[:2] >= (3, 4):
+ JSON_ERROR = 'JSON parse error - Expecting value:'
+else:
+ JSON_ERROR = 'JSON parse error - No JSON object could be decoded'
+
class BasicView(APIView):
def get(self, request, *args, **kwargs):
@@ -48,7 +54,7 @@ def sanitise_json_error(error_dict):
of json.
"""
ret = copy.copy(error_dict)
- chop = len('JSON parse error - No JSON object could be decoded')
+ chop = len(JSON_ERROR)
ret['detail'] = ret['detail'][:chop]
return ret
@@ -61,7 +67,7 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
+ 'detail': JSON_ERROR
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
@@ -76,7 +82,7 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data)
response = self.view(request)
expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
+ 'detail': JSON_ERROR
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
@@ -90,7 +96,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
+ 'detail': JSON_ERROR
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
@@ -105,7 +111,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data)
response = self.view(request)
expected = {
- 'detail': 'JSON parse error - No JSON object could be decoded'
+ 'detail': JSON_ERROR
}
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(sanitise_json_error(response.data), expected)
--
cgit v1.2.3
From 5c12b0768166376783d62632e562f0c1301ee847 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Fri, 16 May 2014 19:40:02 +0200
Subject: Added missing import.
---
rest_framework/serializers.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 2a0d5263..6dd09f68 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -21,6 +21,7 @@ from django.core.paginator import Page
from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
+from django.core.exceptions import ObjectDoesNotExist
from rest_framework.compat import get_concrete_model, six
from rest_framework.settings import api_settings
--
cgit v1.2.3
From a2e1024f8b0447a712d1f486172d38cfe56535fe Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Sun, 18 May 2014 09:27:23 +0200
Subject: Updated Django versions.
---
.travis.yml | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index 0c9b4455..638d1499 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -8,10 +8,10 @@ python:
- "3.4"
env:
- - DJANGO="https://www.djangoproject.com/download/1.7b2/tarball/"
- - DJANGO="django==1.6.3"
- - DJANGO="django==1.5.6"
- - DJANGO="django==1.4.11"
+ - DJANGO="https://www.djangoproject.com/download/1.7b4/tarball/"
+ - DJANGO="django==1.6.5"
+ - DJANGO="django==1.5.8"
+ - DJANGO="django==1.4.13"
- DJANGO="django==1.3.7"
install:
@@ -24,7 +24,7 @@ install:
- "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi"
- "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.7; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} == '3' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- - "if [[ ${DJANGO} == 'https://www.djangoproject.com/download/1.7b2/tarball/' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
+ - "if [[ ${DJANGO} == 'https://www.djangoproject.com/download/1.7b4/tarball/' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- export PYTHONPATH=.
script:
@@ -33,16 +33,16 @@ script:
matrix:
exclude:
- python: "2.6"
- env: DJANGO="https://www.djangoproject.com/download/1.7b2/tarball/"
+ env: DJANGO="https://www.djangoproject.com/download/1.7b4/tarball/"
- python: "3.2"
- env: DJANGO="django==1.4.11"
+ env: DJANGO="django==1.4.13"
- python: "3.2"
env: DJANGO="django==1.3.7"
- python: "3.3"
- env: DJANGO="django==1.4.11"
+ env: DJANGO="django==1.4.13"
- python: "3.3"
env: DJANGO="django==1.3.7"
- python: "3.4"
- env: DJANGO="django==1.4.11"
+ env: DJANGO="django==1.4.13"
- python: "3.4"
env: DJANGO="django==1.3.7"
--
cgit v1.2.3
From af1ee3e63175d2b1fd30ab18091bed1019ac5de6 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Sun, 18 May 2014 09:38:46 +0200
Subject: Fixed a small change in the 1.7 beta url.
---
.travis.yml | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index 638d1499..b2da9e81 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -8,7 +8,7 @@ python:
- "3.4"
env:
- - DJANGO="https://www.djangoproject.com/download/1.7b4/tarball/"
+ - DJANGO="https://www.djangoproject.com/download/1.7.b4/tarball/"
- DJANGO="django==1.6.5"
- DJANGO="django==1.5.8"
- DJANGO="django==1.4.13"
@@ -24,7 +24,7 @@ install:
- "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi"
- "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.7; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} == '3' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- - "if [[ ${DJANGO} == 'https://www.djangoproject.com/download/1.7b4/tarball/' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
+ - "if [[ ${DJANGO} == 'https://www.djangoproject.com/download/1.7.b4/tarball/' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- export PYTHONPATH=.
script:
@@ -33,7 +33,7 @@ script:
matrix:
exclude:
- python: "2.6"
- env: DJANGO="https://www.djangoproject.com/download/1.7b4/tarball/"
+ env: DJANGO="https://www.djangoproject.com/download/1.7.b4/tarball/"
- python: "3.2"
env: DJANGO="django==1.4.13"
- python: "3.2"
--
cgit v1.2.3
From a1a3ad763996b9ab5535bc5d442c2d6fab10b7cc Mon Sep 17 00:00:00 2001
From: allenhu
Date: Sat, 17 May 2014 06:05:33 +0800
Subject: fix pep8
---
rest_framework/serializers.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 6dd09f68..87d20cfc 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -33,8 +33,8 @@ from rest_framework.settings import api_settings
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
-from rest_framework.relations import *
-from rest_framework.fields import *
+from rest_framework.relations import * # NOQA
+from rest_framework.fields import * # NOQA
def _resolve_model(obj):
@@ -345,7 +345,7 @@ class BaseSerializer(WritableField):
for field_name, field in self.fields.items():
if field.read_only and obj is None:
- continue
+ continue
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
--
cgit v1.2.3
From 1e7b5fd2c04e587e30cf29e15ca3074b8d33b92e Mon Sep 17 00:00:00 2001
From: Ian Foote
Date: Tue, 20 May 2014 14:55:00 +0100
Subject: Document ChoiceField blank_display_value parameter
---
docs/api-guide/fields.md | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md
index 67fa65d2..58dbf977 100644
--- a/docs/api-guide/fields.md
+++ b/docs/api-guide/fields.md
@@ -184,7 +184,9 @@ Corresponds to `django.db.models.fields.SlugField`.
## ChoiceField
-A field that can accept a value out of a limited set of choices.
+A field that can accept a value out of a limited set of choices. Optionally takes a `blank_display_value` parameter that customizes the display value of an empty choice.
+
+**Signature:** `ChoiceField(choices=(), blank_display_value=None)`
## EmailField
--
cgit v1.2.3
From 04c820b8e5e4ae153eacd1cbf19b39286c374e87 Mon Sep 17 00:00:00 2001
From: John Spray
Date: Thu, 22 May 2014 15:24:35 +0100
Subject: fields: allow help_text on SerializerMethodField
...by passing through any extra *args and **kwargs
to the parent constructor.
Previously one couldn't assign help_text to a
SerializerMethodField during construction.
---
rest_framework/fields.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 2da89550..4ac5285e 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1027,9 +1027,9 @@ class SerializerMethodField(Field):
A field that gets its value by calling a method on the serializer it's attached to.
"""
- def __init__(self, method_name):
+ def __init__(self, method_name, *args, **kwargs):
self.method_name = method_name
- super(SerializerMethodField, self).__init__()
+ super(SerializerMethodField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
value = getattr(self.parent, self.method_name)(obj)
--
cgit v1.2.3
From 807f7a6bb9e36321f3487b5ac31ef5fdc8f4b3fb Mon Sep 17 00:00:00 2001
From: Piper Merriam
Date: Thu, 22 May 2014 13:51:20 -0600
Subject: Fix _resolve_model to work with unicode strings
---
rest_framework/serializers.py | 14 +++++++-------
rest_framework/tests/test_serializers.py | 5 +++++
2 files changed, 12 insertions(+), 7 deletions(-)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 87d20cfc..c2b414d7 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -49,7 +49,7 @@ def _resolve_model(obj):
String representations should have the format:
'appname.ModelName'
"""
- if type(obj) == str and len(obj.split('.')) == 2:
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
@@ -759,9 +759,9 @@ class ModelSerializer(Serializer):
field.read_only = True
ret[accessor_name] = field
-
+
# Ensure that 'read_only_fields' is an iterable
- assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
+ assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
# Add the `read_only` flag to any fields that have been specified
# in the `read_only_fields` option
@@ -776,10 +776,10 @@ class ModelSerializer(Serializer):
"on serializer '%s'." %
(field_name, self.__class__.__name__))
ret[field_name].read_only = True
-
+
# Ensure that 'write_only_fields' is an iterable
- assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
-
+ assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
+
for field_name in self.opts.write_only_fields:
assert field_name not in self.base_fields.keys(), (
"field '%s' on serializer '%s' specified in "
@@ -790,7 +790,7 @@ class ModelSerializer(Serializer):
"Non-existant field '%s' specified in `write_only_fields` "
"on serializer '%s'." %
(field_name, self.__class__.__name__))
- ret[field_name].write_only = True
+ ret[field_name].write_only = True
return ret
diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py
index 082a400c..120510ac 100644
--- a/rest_framework/tests/test_serializers.py
+++ b/rest_framework/tests/test_serializers.py
@@ -3,6 +3,7 @@ from django.test import TestCase
from rest_framework.serializers import _resolve_model
from rest_framework.tests.models import BasicModel
+from rest_framework.compat import six
class ResolveModelTests(TestCase):
@@ -19,6 +20,10 @@ class ResolveModelTests(TestCase):
resolved_model = _resolve_model('tests.BasicModel')
self.assertEqual(resolved_model, BasicModel)
+ def test_resolve_unicode_representation(self):
+ resolved_model = _resolve_model(six.text_type('tests.BasicModel'))
+ self.assertEqual(resolved_model, BasicModel)
+
def test_resolve_non_django_model(self):
with self.assertRaises(ValueError):
_resolve_model(TestCase)
--
cgit v1.2.3
From eab5933070d5df9078a6b88e85ee933cbfa28955 Mon Sep 17 00:00:00 2001
From: khamaileon
Date: Mon, 26 May 2014 18:43:50 +0200
Subject: Add the allow_add_remove parameter to the get_serializer method
---
docs/api-guide/generic-views.md | 2 +-
rest_framework/generics.py | 8 +++++---
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/docs/api-guide/generic-views.md b/docs/api-guide/generic-views.md
index 7d06f246..bb748981 100755
--- a/docs/api-guide/generic-views.md
+++ b/docs/api-guide/generic-views.md
@@ -187,7 +187,7 @@ Remember that the `pre_save()` method is not called by `GenericAPIView` itself,
You won't typically need to override the following methods, although you might need to call into them if you're writing custom views using `GenericAPIView`.
* `get_serializer_context(self)` - Returns a dictionary containing any extra context that should be supplied to the serializer. Defaults to including `'request'`, `'view'` and `'format'` keys.
-* `get_serializer(self, instance=None, data=None, files=None, many=False, partial=False)` - Returns a serializer instance.
+* `get_serializer(self, instance=None, data=None, files=None, many=False, partial=False, allow_add_remove=False)` - Returns a serializer instance.
* `get_pagination_serializer(self, page)` - Returns a serializer instance to use with paginated data.
* `paginate_queryset(self, queryset)` - Paginate a queryset if required, either returning a page object, or `None` if pagination is not configured for this view.
* `filter_queryset(self, queryset)` - Given a queryset, filter it with whichever filter backends are in use, returning a new queryset.
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 7bac510f..7fc9db36 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -90,8 +90,8 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer(self, instance=None, data=None,
- files=None, many=False, partial=False):
+ def get_serializer(self, instance=None, data=None, files=None, many=False,
+ partial=False, allow_add_remove=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
@@ -99,7 +99,9 @@ class GenericAPIView(views.APIView):
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
return serializer_class(instance, data=data, files=files,
- many=many, partial=partial, context=context)
+ many=many, partial=partial,
+ allow_add_remove=allow_add_remove,
+ context=context)
def get_pagination_serializer(self, page):
"""
--
cgit v1.2.3
From a7ff51118f8c8d696219ea7723b283a0ee680457 Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 29 May 2014 14:33:16 +0100
Subject: Note on configuring TokenAuthentication
---
docs/api-guide/authentication.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/api-guide/authentication.md b/docs/api-guide/authentication.md
index 88a7a011..1cb37d67 100755
--- a/docs/api-guide/authentication.md
+++ b/docs/api-guide/authentication.md
@@ -119,7 +119,7 @@ Unauthenticated responses that are denied permission will result in an `HTTP 401
This authentication scheme uses a simple token-based HTTP Authentication scheme. Token authentication is appropriate for client-server setups, such as native desktop and mobile clients.
-To use the `TokenAuthentication` scheme, include `rest_framework.authtoken` in your `INSTALLED_APPS` setting:
+To use the `TokenAuthentication` scheme you'll need to [configure the authentication classes](#setting-the-authentication-scheme) to include `TokenAuthentication`, and additionally include `rest_framework.authtoken` in your `INSTALLED_APPS` setting:
INSTALLED_APPS = (
...
--
cgit v1.2.3
From 6cb6bfae1b83c8682fa3c3d208c732c8ea49606e Mon Sep 17 00:00:00 2001
From: Danilo Bargen
Date: Fri, 30 May 2014 17:53:26 +0200
Subject: Always use specified content type in APIRequestFactory
If `content_type` is specified in the `APIRequestFactory`, always
include it in the request, even if data is empty.
---
rest_framework/test.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rest_framework/test.py b/rest_framework/test.py
index df5a5b3b..284bcee0 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -36,7 +36,7 @@ class APIRequestFactory(DjangoRequestFactory):
"""
if not data:
- return ('', None)
+ return ('', content_type)
assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.'
--
cgit v1.2.3
From 31f63e1e5502d45f414df400679c238346137b10 Mon Sep 17 00:00:00 2001
From: Rodolfo Carvalho
Date: Mon, 2 Jun 2014 11:06:03 +0200
Subject: Fix typo in docs
---
docs/api-guide/viewsets.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md
index 23b16575..b3085f75 100644
--- a/docs/api-guide/viewsets.md
+++ b/docs/api-guide/viewsets.md
@@ -137,7 +137,7 @@ The `@action` and `@link` decorators can additionally take extra arguments that
def set_password(self, request, pk=None):
...
-The `@action` decorator will route `POST` requests by default, but may also accept other HTTP methods, by using the `method` argument. For example:
+The `@action` decorator will route `POST` requests by default, but may also accept other HTTP methods, by using the `methods` argument. For example:
@action(methods=['POST', 'DELETE'])
def unset_password(self, request, pk=None):
--
cgit v1.2.3
From 08c4594145a7219a14fafc87db0b9d61483d70d0 Mon Sep 17 00:00:00 2001
From: khamaileon
Date: Thu, 5 Jun 2014 12:49:02 +0200
Subject: Replace ChoiceField type_label
---
rest_framework/fields.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 4ac5285e..86e8fd9d 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -506,7 +506,7 @@ class SlugField(CharField):
class ChoiceField(WritableField):
type_name = 'ChoiceField'
- type_label = 'multiple choice'
+ type_label = 'choice'
form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = {
--
cgit v1.2.3
From e8ec81f5e985f9cc9f524f77ec23013be918b990 Mon Sep 17 00:00:00 2001
From: Xavier Ordoquy
Date: Sun, 8 Jun 2014 09:03:21 +0200
Subject: Fixed #1624 (thanks @abraithwaite)
---
rest_framework/compat.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index d155f554..fdf12448 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -51,6 +51,7 @@ except ImportError:
# guardian is optional
try:
import guardian
+ import guardian.shortcuts # Fixes #1624
except ImportError:
guardian = None
--
cgit v1.2.3
From be84f71bc906c926c9955a4cf47630b24461067d Mon Sep 17 00:00:00 2001
From: Greg Barker
Date: Tue, 10 Jun 2014 15:20:45 -0700
Subject: Fix #1614 - Corrected reference to serializers.CharField
---
docs/api-guide/serializers.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md
index 0044f070..cedf1ff7 100644
--- a/docs/api-guide/serializers.md
+++ b/docs/api-guide/serializers.md
@@ -73,8 +73,8 @@ Sometimes when serializing objects, you may not want to represent everything exa
If you need to customize the serialized value of a particular field, you can do this by creating a `transform_` method. For example if you needed to render some markdown from a text field:
- description = serializers.TextField()
- description_html = serializers.TextField(source='description', read_only=True)
+ description = serializers.CharField()
+ description_html = serializers.CharField(source='description', read_only=True)
def transform_description_html(self, obj, value):
from django.contrib.markup.templatetags.markup import markdown
--
cgit v1.2.3
From 1386767013d044d337b8e08dd2f9b0197197cccf Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 12 Jun 2014 11:47:26 +0100
Subject: Version 2.3.14
---
docs/api-guide/content-negotiation.md | 2 --
docs/topics/release-notes.md | 40 +++++++++++++++------------
rest_framework/__init__.py | 2 +-
rest_framework/templatetags/rest_framework.py | 4 +--
4 files changed, 25 insertions(+), 23 deletions(-)
diff --git a/docs/api-guide/content-negotiation.md b/docs/api-guide/content-negotiation.md
index 94dd59ca..58b2a2ce 100644
--- a/docs/api-guide/content-negotiation.md
+++ b/docs/api-guide/content-negotiation.md
@@ -1,5 +1,3 @@
-
-
# Content negotiation
> HTTP has provisions for several mechanisms for "content negotiation" - the process of selecting the best representation for a given response when there are multiple representations available.
diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md
index 335497ee..ea4c912c 100644
--- a/docs/topics/release-notes.md
+++ b/docs/topics/release-notes.md
@@ -40,24 +40,28 @@ You can determine your currently installed version using `pip freeze`:
## 2.3.x series
-### 2.3.x
-
-**Date**: April 2014
-
-* Fix nested serializers linked through a backward foreign key relation
-* Fix bad links for the `BrowsableAPIRenderer` with `YAMLRenderer`
-* Add `UnicodeYAMLRenderer` that extends `YAMLRenderer` with unicode
-* Fix `parse_header` argument convertion
-* Fix mediatype detection under Python3
-* Web browseable API now offers blank option on dropdown when the field is not required
-* `APIException` representation improved for logging purposes
-* Allow source="*" within nested serializers
-* Better support for custom oauth2 provider backends
-* Fix field validation if it's optional and has no value
-* Add `SEARCH_PARAM` and `ORDERING_PARAM`
-* Fix `APIRequestFactory` to support arguments within the url string for GET
-* Allow three transport modes for access tokens when accessing a protected resource
-* Fix `Request`'s `QueryDict` encoding
+### 2.3.14
+
+**Date**: 12th June 2014
+
+* **Security fix**: Escape request path when it is include as part of the login and logout links in the browsable API.
+* `help_text` and `verbose_name` automatically set for related fields on `ModelSerializer`.
+* Fix nested serializers linked through a backward foreign key relation.
+* Fix bad links for the `BrowsableAPIRenderer` with `YAMLRenderer`.
+* Add `UnicodeYAMLRenderer` that extends `YAMLRenderer` with unicode.
+* Fix `parse_header` argument convertion.
+* Fix mediatype detection under Python 3.
+* Web browseable API now offers blank option on dropdown when the field is not required.
+* `APIException` representation improved for logging purposes.
+* Allow source="*" within nested serializers.
+* Better support for custom oauth2 provider backends.
+* Fix field validation if it's optional and has no value.
+* Add `SEARCH_PARAM` and `ORDERING_PARAM`.
+* Fix `APIRequestFactory` to support arguments within the url string for GET.
+* Allow three transport modes for access tokens when accessing a protected resource.
+* Fix `QueryDict` encoding on request objects.
+* Ensure throttle keys do not contain spaces, as those are invalid if using `memcached`.
+* Support `blank_display_value` on `ChoiceField`.
### 2.3.13
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 2d76b55d..01036cef 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -8,7 +8,7 @@ ______ _____ _____ _____ __ _
"""
__title__ = 'Django REST framework'
-__version__ = '2.3.13'
+__version__ = '2.3.14'
__author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause'
__copyright__ = 'Copyright 2011-2014 Tom Christie'
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index dff176d6..a155d8d2 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -122,7 +122,7 @@ def optional_login(request):
except NoReverseMatch:
return ''
- snippet = "Log in" % (login_url, request.path)
+ snippet = "Log in" % (login_url, escape(request.path))
return snippet
@@ -136,7 +136,7 @@ def optional_logout(request):
except NoReverseMatch:
return ''
- snippet = "Log out" % (logout_url, request.path)
+ snippet = "Log out" % (logout_url, escape(request.path))
return snippet
--
cgit v1.2.3
From 82659873c9d3e3058b7e7ea63e4c4b190c7fc19c Mon Sep 17 00:00:00 2001
From: Tom Christie
Date: Thu, 12 Jun 2014 11:48:58 +0100
Subject: Fix accidental docs change
---
docs/api-guide/content-negotiation.md | 2 ++
1 file changed, 2 insertions(+)
diff --git a/docs/api-guide/content-negotiation.md b/docs/api-guide/content-negotiation.md
index 58b2a2ce..94dd59ca 100644
--- a/docs/api-guide/content-negotiation.md
+++ b/docs/api-guide/content-negotiation.md
@@ -1,3 +1,5 @@
+
+
# Content negotiation
> HTTP has provisions for several mechanisms for "content negotiation" - the process of selecting the best representation for a given response when there are multiple representations available.
--
cgit v1.2.3