aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2012-11-09 13:49:52 +0000
committerTom Christie2012-11-09 13:49:52 +0000
commit8953a60196cb55ec75902882314da5a42636349c (patch)
tree43bf6ea1f69955aeecd83fb9f866d92ea9a5f3df /rest_framework
parentb78872b7dbb55f1aa2d21f15fbb952f0c7156326 (diff)
parent9aaeeacdfebc244850e82469e4af45af252cca4d (diff)
downloaddjango-rest-framework-8953a60196cb55ec75902882314da5a42636349c.tar.bz2
Merge with master
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/compat.py38
-rw-r--r--rest_framework/decorators.py14
-rw-r--r--rest_framework/exceptions.py8
-rw-r--r--rest_framework/fields.py409
-rw-r--r--rest_framework/filters.py59
-rw-r--r--rest_framework/generics.py85
-rw-r--r--rest_framework/mixins.py35
-rw-r--r--rest_framework/negotiation.py36
-rw-r--r--rest_framework/pagination.py17
-rw-r--r--rest_framework/parsers.py50
-rw-r--r--rest_framework/permissions.py13
-rw-r--r--rest_framework/renderers.py188
-rw-r--r--rest_framework/request.py27
-rw-r--r--rest_framework/resources.py96
-rw-r--r--rest_framework/response.py14
-rw-r--r--rest_framework/reverse.py8
-rwxr-xr-xrest_framework/runtests/runcoverage.py4
-rwxr-xr-xrest_framework/runtests/runtests.py2
-rw-r--r--rest_framework/runtests/settings.py14
-rw-r--r--rest_framework/serializers.py169
-rw-r--r--rest_framework/settings.py55
-rw-r--r--rest_framework/static/rest_framework/css/default.css21
-rw-r--r--rest_framework/status.py2
-rw-r--r--rest_framework/templates/rest_framework/base.html14
-rw-r--r--rest_framework/templates/rest_framework/login.html56
-rw-r--r--rest_framework/templatetags/rest_framework.py25
-rw-r--r--rest_framework/tests/__init__.py13
-rw-r--r--rest_framework/tests/filterset.py168
-rw-r--r--rest_framework/tests/genericrelations.py2
-rw-r--r--rest_framework/tests/generics.py51
-rw-r--r--rest_framework/tests/htmlrenderer.py71
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py89
-rw-r--r--rest_framework/tests/models.py60
-rw-r--r--rest_framework/tests/negotiation.py10
-rw-r--r--rest_framework/tests/pagination.py81
-rw-r--r--rest_framework/tests/pk_relations.py205
-rw-r--r--rest_framework/tests/renderers.py89
-rw-r--r--rest_framework/tests/request.py2
-rw-r--r--rest_framework/tests/response.py6
-rw-r--r--rest_framework/tests/serializer.py295
-rw-r--r--rest_framework/tests/tests.py13
-rw-r--r--rest_framework/tests/validators.py10
-rw-r--r--rest_framework/throttling.py6
-rw-r--r--rest_framework/urlpatterns.py13
-rw-r--r--rest_framework/utils/__init__.py1
-rw-r--r--rest_framework/utils/breadcrumbs.py14
-rw-r--r--rest_framework/views.py49
48 files changed, 2120 insertions, 589 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 557f5943..fd176603 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,3 +1,3 @@
-__version__ = '2.0.0'
+__version__ = '2.1.2'
VERSION = __version__ # synonym
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 7664c400..02e50604 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -1,8 +1,17 @@
"""
-The :mod:`compat` module provides support for backwards compatibility with older versions of django/python.
+The `compat` module provides support for backwards compatibility with older
+versions of django/python, and compatbility wrappers around optional packages.
"""
+# flake8: noqa
import django
+# django-filter is optional
+try:
+ import django_filters
+except:
+ django_filters = None
+
+
# cStringIO only if it's available, otherwise StringIO
try:
import cStringIO as StringIO
@@ -346,33 +355,6 @@ except ImportError:
yaml = None
-import unittest
-try:
- import unittest.skip
-except ImportError: # python < 2.7
- from unittest import TestCase
- import functools
-
- def skip(reason):
- # Pasted from py27/lib/unittest/case.py
- """
- Unconditionally skip a test.
- """
- def decorator(test_item):
- if not (isinstance(test_item, type) and issubclass(test_item, TestCase)):
- @functools.wraps(test_item)
- def skip_wrapper(*args, **kwargs):
- pass
- test_item = skip_wrapper
-
- test_item.__unittest_skip__ = True
- test_item.__unittest_skip_why__ = reason
- return test_item
- return decorator
-
- unittest.skip = skip
-
-
# xml.etree.parse only throws ParseError for python >= 2.7
try:
from xml.etree import ParseError as ETParseError
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 948973ae..a231f191 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -10,8 +10,18 @@ def api_view(http_method_names):
def decorator(func):
- class WrappedAPIView(APIView):
- pass
+ WrappedAPIView = type(
+ 'WrappedAPIView',
+ (APIView,),
+ {'__doc__': func.__doc__}
+ )
+
+ # Note, the above allows us to set the docstring.
+ # It is the equivelent of:
+ #
+ # class WrappedAPIView(APIView):
+ # pass
+ # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 6ae0c95c..d635351c 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -47,14 +47,6 @@ class PermissionDenied(APIException):
self.detail = detail or self.default_detail
-class InvalidFormat(APIException):
- status_code = status.HTTP_404_NOT_FOUND
- default_detail = "Format suffix '.%s' not found."
-
- def __init__(self, format, detail=None):
- self.detail = (detail or self.default_detail) % format
-
-
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = "Method '%s' not allowed."
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index bb9a523d..a4e29a30 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -5,13 +5,16 @@ import warnings
from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve
+from django.core.urlresolvers import resolve, get_script_prefix
from django.conf import settings
+from django.forms import widgets
+from django.forms.models import ModelChoiceIterator
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
from rest_framework.reverse import reverse
from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone
+from urlparse import urlparse
def is_simple_callable(obj):
@@ -37,12 +40,12 @@ class Field(object):
self.source = source
- def initialize(self, parent):
+ def initialize(self, parent, field_name):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
- model_field - The model field this field corrosponds to, if one exists.
+ model_field - The model field this field corresponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
@@ -70,6 +73,8 @@ class Field(object):
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
+ if is_simple_callable(value):
+ value = value()
else:
value = getattr(obj, field_name)
return self.to_native(value)
@@ -85,6 +90,8 @@ class Field(object):
return value
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
+ elif isinstance(value, dict):
+ return dict(map(self.to_native, (k, v)) for k, v in value.items())
return smart_unicode(value)
def attributes(self):
@@ -105,15 +112,20 @@ class WritableField(Field):
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
+ widget = widgets.TextInput
+ default = None
+
+ def __init__(self, source=None, read_only=False, required=None,
+ validators=[], error_messages=None, widget=None,
+ default=None, blank=None):
- def __init__(self, source=None, readonly=False, required=None,
- validators=[], error_messages=None):
super(WritableField, self).__init__(source=source)
- self.readonly = readonly
+
+ self.read_only = read_only
if required is None:
- self.required = not(readonly)
+ self.required = not(read_only)
else:
- assert not readonly, "Cannot set required=True and readonly=True"
+ assert not read_only, "Cannot set required=True and read_only=True"
self.required = required
messages = {}
@@ -123,6 +135,14 @@ class WritableField(Field):
self.error_messages = messages
self.validators = self.default_validators + validators
+ self.default = default if default is not None else self.default
+ self.blank = blank
+
+ # Widgets are ony used for HTML forms.
+ widget = widget or self.widget
+ if isinstance(widget, type):
+ widget = widget()
+ self.widget = widget
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
@@ -151,15 +171,18 @@ class WritableField(Field):
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
"""
- if self.readonly:
+ if self.read_only:
return
try:
native = data[field_name]
except KeyError:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
+ if self.default is not None:
+ native = self.default
+ else:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
+ return
value = self.from_native(native)
if self.source == '*':
@@ -179,7 +202,7 @@ class WritableField(Field):
class ModelField(WritableField):
"""
- A generic field that can be used against an arbirtrary model field.
+ A generic field that can be used against an arbitrary model field.
"""
def __init__(self, *args, **kwargs):
try:
@@ -189,11 +212,11 @@ class ModelField(WritableField):
super(ModelField, self).__init__(*args, **kwargs)
def from_native(self, value):
- try:
- rel = self.model_field.rel
- except:
+ rel = getattr(self.model_field, "rel", None)
+ if rel is not None:
+ return rel.to._meta.get_field(rel.field_name).to_python(value)
+ else:
return self.model_field.to_python(value)
- return rel.to._meta.get_field(rel.field_name).to_python(value)
def field_to_native(self, obj, field_name):
value = self.model_field._get_val_from_obj(obj)
@@ -209,32 +232,119 @@ class ModelField(WritableField):
##### Relational fields #####
+# Not actually Writable, but subclasses may need to be.
class RelatedField(WritableField):
"""
Base class for related model fields.
+
+ If not overridden, this represents a to-one relationship, using the unicode
+ representation of the target.
"""
+ widget = widgets.Select
+ cache_choices = False
+ empty_label = None
+ default_read_only = True # TODO: Remove this
+
def __init__(self, *args, **kwargs):
self.queryset = kwargs.pop('queryset', None)
super(RelatedField, self).__init__(*args, **kwargs)
+ self.read_only = kwargs.pop('read_only', self.default_read_only)
+
+ def initialize(self, parent, field_name):
+ super(RelatedField, self).initialize(parent, field_name)
+ if self.queryset is None and not self.read_only:
+ try:
+ manager = getattr(self.parent.opts.model, self.source or field_name)
+ if hasattr(manager, 'related'): # Forward
+ self.queryset = manager.related.model._default_manager.all()
+ else: # Reverse
+ self.queryset = manager.field.rel.to._default_manager.all()
+ except:
+ raise
+ msg = ('Serializer related fields must include a `queryset`' +
+ ' argument or set `read_only=True')
+ raise Exception(msg)
+
+ ### We need this stuff to make form choices work...
+
+ # def __deepcopy__(self, memo):
+ # result = super(RelatedField, self).__deepcopy__(memo)
+ # result.queryset = result.queryset
+ # return result
+
+ def prepare_value(self, obj):
+ return self.to_native(obj)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ def _get_queryset(self):
+ return self._queryset
+
+ def _set_queryset(self, queryset):
+ self._queryset = queryset
+ self.widget.choices = self.choices
+
+ queryset = property(_get_queryset, _set_queryset)
+
+ def _get_choices(self):
+ # If self._choices is set, then somebody must have manually set
+ # the property self.choices. In this case, just return self._choices.
+ if hasattr(self, '_choices'):
+ return self._choices
+
+ # Otherwise, execute the QuerySet in self.queryset to determine the
+ # choices dynamically. Return a fresh ModelChoiceIterator that has not been
+ # consumed. Note that we're instantiating a new ModelChoiceIterator *each*
+ # time _get_choices() is called (and, thus, each time self.choices is
+ # accessed) so that we can ensure the QuerySet has not been consumed. This
+ # construct might look complicated but it allows for lazy evaluation of
+ # the queryset.
+ return ModelChoiceIterator(self)
+
+ def _set_choices(self, value):
+ # Setting choices also sets the choices on the widget.
+ # choices can be any iterable, but we call list() on it because
+ # it will be consumed more than once.
+ self._choices = self.widget.choices = list(value)
+
+ choices = property(_get_choices, _set_choices)
+
+ ### Regular serializier stuff...
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
return self.to_native(value)
def field_from_native(self, data, field_name, into):
+ if self.read_only:
+ return
+
value = data.get(field_name)
- into[(self.source or field_name) + '_id'] = self.from_native(value)
+ into[(self.source or field_name)] = self.from_native(value)
class ManyRelatedMixin(object):
"""
Mixin to convert a related field to a many related field.
"""
+ widget = widgets.SelectMultiple
+
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
return [self.to_native(item) for item in value.all()]
def field_from_native(self, data, field_name, into):
+ if self.read_only:
+ return
+
try:
# Form data
value = data.getlist(self.source or field_name)
@@ -250,6 +360,9 @@ class ManyRelatedMixin(object):
class ManyRelatedField(ManyRelatedMixin, RelatedField):
"""
Base class for related model managers.
+
+ If not overridden, this represents a to-many relationship, using the unicode
+ representations of the target, and is read-only.
"""
pass
@@ -258,12 +371,38 @@ class ManyRelatedField(ManyRelatedMixin, RelatedField):
class PrimaryKeyRelatedField(RelatedField):
"""
- Serializes a related field or related object to a pk value.
+ Represents a to-one relationship as a pk value.
"""
+ default_read_only = False
+
+ # TODO: Remove these field hacks...
+ def prepare_value(self, obj):
+ return self.to_native(obj.pk)
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj.pk))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ # TODO: Possibly change this to just take `obj`, through prob less performant
def to_native(self, pk):
return pk
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
+ raise ValidationError(msg)
+
def field_to_native(self, obj, field_name):
try:
# Prefer obj.serializable_value for performance reasons
@@ -278,8 +417,23 @@ class PrimaryKeyRelatedField(RelatedField):
class ManyPrimaryKeyRelatedField(ManyRelatedField):
"""
- Serializes a to-many related field or related manager to a pk value.
+ Represents a to-many relationship as a pk value.
"""
+ default_read_only = False
+
+ def prepare_value(self, obj):
+ return self.to_native(obj.pk)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj.pk))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
def to_native(self, pk):
return pk
@@ -294,27 +448,83 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
# Forward relationship
return [self.to_native(item.pk) for item in queryset.all()]
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
+ raise ValidationError(msg)
+
+### Slug relationships
+
+
+class SlugRelatedField(RelatedField):
+ default_read_only = False
+
+ def __init__(self, *args, **kwargs):
+ self.slug_field = kwargs.pop('slug_field', None)
+ assert self.slug_field, 'slug_field is required'
+ super(SlugRelatedField, self).__init__(*args, **kwargs)
+
+ def to_native(self, obj):
+ return getattr(obj, self.slug_field)
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(**{self.slug_field: data})
+ except ObjectDoesNotExist:
+ raise ValidationError('Object with %s=%s does not exist.' %
+ (self.slug_field, unicode(data)))
+
+
+class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
+ pass
+
### Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField):
+ """
+ Represents a to-one relationship, using hyperlinking.
+ """
pk_url_kwarg = 'pk'
- slug_url_kwarg = 'slug'
slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+ default_read_only = False
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+
+ self.slug_field = kwargs.pop('slug_field', self.slug_field)
+ default_slug_kwarg = self.slug_url_kwarg or self.slug_field
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
+ self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+
+ self.format = kwargs.pop('format', None)
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
+ def get_slug_field(self):
+ """
+ Get the name of a slug field to be used to look up by slug.
+ """
+ return self.slug_field
+
def to_native(self, obj):
view_name = self.view_name
request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
kwargs = {self.pk_url_kwarg: obj.pk}
try:
- return reverse(view_name, kwargs=kwargs, request=request)
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
except:
pass
@@ -325,13 +535,13 @@ class HyperlinkedRelatedField(RelatedField):
kwargs = {self.slug_url_kwarg: slug}
try:
- return reverse(self.view_name, kwargs=kwargs, request=request)
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
except:
pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try:
- return reverse(self.view_name, kwargs=kwargs, request=request)
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
except:
pass
@@ -340,6 +550,16 @@ class HyperlinkedRelatedField(RelatedField):
def from_native(self, value):
# Convert URL -> model instance pk
# TODO: Use values_list
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ if value.startswith('http:') or value.startswith('https:'):
+ # If needed convert absolute URLs to relative path
+ value = urlparse(value).path
+ prefix = get_script_prefix()
+ if value.startswith(prefix):
+ value = '/' + value[len(prefix):]
+
try:
match = resolve(value)
except:
@@ -353,7 +573,7 @@ class HyperlinkedRelatedField(RelatedField):
# Try explicit primary key.
if pk is not None:
- return pk
+ queryset = self.queryset.filter(pk=pk)
# Next, try looking up by slug.
elif slug is not None:
slug_field = self.get_slug_field()
@@ -366,48 +586,88 @@ class HyperlinkedRelatedField(RelatedField):
obj = queryset.get()
except ObjectDoesNotExist:
raise ValidationError('Invalid hyperlink - object does not exist.')
- return obj.pk
+ return obj
class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
+ """
+ Represents a to-many relationship, using hyperlinking.
+ """
pass
class HyperlinkedIdentityField(Field):
"""
- A field that represents the model's identity using a hyperlink.
+ Represents the instance, or a property on the instance, using hyperlinking.
"""
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+
def __init__(self, *args, **kwargs):
- # TODO: Make this mandatory, and have the HyperlinkedModelSerializer
- # set it on-the-fly
+ # TODO: Make view_name mandatory, and have the
+ # HyperlinkedModelSerializer set it on-the-fly
self.view_name = kwargs.pop('view_name', None)
+ self.format = kwargs.pop('format', None)
+
+ self.slug_field = kwargs.pop('slug_field', self.slug_field)
+ default_slug_kwarg = self.slug_url_kwarg or self.slug_field
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
+ self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+
super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
view_name = self.view_name or self.parent.opts.view_name
- view_kwargs = {'pk': obj.pk}
- return reverse(view_name, kwargs=view_kwargs, request=request)
+ kwargs = {self.pk_url_kwarg: obj.pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ slug = getattr(obj, self.slug_field, None)
+
+ if not slug:
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
##### Typed Fields #####
class BooleanField(WritableField):
type_name = 'BooleanField'
+ widget = widgets.CheckboxInput
default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."),
}
+ empty = False
+
+ # Note: we set default to `False` in order to fill in missing value not
+ # supplied by html form. TODO: Fix so that only html form input gets
+ # this behavior.
+ default = False
def from_native(self, value):
- if value in (True, False):
- # if value is 1 or 0 than it's equal to True or False, but we want
- # to return a true bool for semantic reasons.
- return bool(value)
if value in ('t', 'True', '1'):
return True
if value in ('f', 'False', '0'):
return False
- raise ValidationError(self.error_messages['invalid'] % value)
+ return bool(value)
class CharField(WritableField):
@@ -421,12 +681,68 @@ class CharField(WritableField):
if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length))
+ def validate(self, value):
+ """
+ Validates that the value is supplied (if required).
+ """
+ # if empty string and allow blank
+ if self.blank and not value:
+ return
+ else:
+ super(CharField, self).validate(value)
+
def from_native(self, value):
if isinstance(value, basestring) or value is None:
return value
return smart_unicode(value)
+class ChoiceField(WritableField):
+ type_name = 'ChoiceField'
+ widget = widgets.Select
+ default_error_messages = {
+ 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'),
+ }
+
+ def __init__(self, choices=(), *args, **kwargs):
+ super(ChoiceField, self).__init__(*args, **kwargs)
+ self.choices = choices
+
+ def _get_choices(self):
+ return self._choices
+
+ def _set_choices(self, value):
+ # Setting choices also sets the choices on the widget.
+ # choices can be any iterable, but we call list() on it because
+ # it will be consumed more than once.
+ self._choices = self.widget.choices = list(value)
+
+ choices = property(_get_choices, _set_choices)
+
+ def validate(self, value):
+ """
+ Validates that the input is in self.choices.
+ """
+ super(ChoiceField, self).validate(value)
+ if value and not self.valid_value(value):
+ raise ValidationError(self.error_messages['invalid_choice'] % {'value': value})
+
+ def valid_value(self, value):
+ """
+ Check to see if the provided value is a valid choice.
+ """
+ for k, v in self.choices:
+ if isinstance(v, (list, tuple)):
+ # This is an optgroup, so look inside the group for options
+ for k2, v2 in v:
+ if value == smart_unicode(k2):
+ return True
+ else:
+ if value == smart_unicode(k):
+ return True
+ return False
+
+
class EmailField(CharField):
type_name = 'EmailField'
@@ -436,7 +752,10 @@ class EmailField(CharField):
default_validators = [validators.validate_email]
def from_native(self, value):
- return super(EmailField, self).from_native(value).strip()
+ ret = super(EmailField, self).from_native(value)
+ if ret is None:
+ return None
+ return ret.strip()
def __deepcopy__(self, memo):
result = copy.copy(self)
@@ -458,8 +777,9 @@ class DateField(WritableField):
empty = None
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
if isinstance(value, datetime.datetime):
if timezone and settings.USE_TZ and timezone.is_aware(value):
# Convert aware datetimes to the default time zone
@@ -497,8 +817,9 @@ class DateTimeField(WritableField):
empty = None
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
@@ -556,6 +877,7 @@ class IntegerField(WritableField):
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
+
try:
value = int(str(value))
except (ValueError, TypeError):
@@ -571,8 +893,9 @@ class FloatField(WritableField):
}
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
try:
return float(value)
except (TypeError, ValueError):
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
new file mode 100644
index 00000000..ccae4825
--- /dev/null
+++ b/rest_framework/filters.py
@@ -0,0 +1,59 @@
+from rest_framework.compat import django_filters
+
+FilterSet = django_filters and django_filters.FilterSet or None
+
+
+class BaseFilterBackend(object):
+ """
+ A base class from which all filter backend classes should inherit.
+ """
+
+ def filter_queryset(self, request, queryset, view):
+ """
+ Return a filtered queryset.
+ """
+ raise NotImplementedError(".filter_queryset() must be overridden.")
+
+
+class DjangoFilterBackend(BaseFilterBackend):
+ """
+ A filter backend that uses django-filter.
+ """
+ default_filter_set = FilterSet
+
+ def __init__(self):
+ assert django_filters, 'Using DjangoFilterBackend, but django-filter is not installed'
+
+ def get_filter_class(self, view):
+ """
+ Return the django-filters `FilterSet` used to filter the queryset.
+ """
+ filter_class = getattr(view, 'filter_class', None)
+ filter_fields = getattr(view, 'filter_fields', None)
+ view_model = getattr(view, 'model', None)
+
+ if filter_class:
+ filter_model = filter_class.Meta.model
+
+ assert issubclass(filter_model, view_model), \
+ 'FilterSet model %s does not match view model %s' % \
+ (filter_model, view_model)
+
+ return filter_class
+
+ if filter_fields:
+ class AutoFilterSet(self.default_filter_set):
+ class Meta:
+ model = view_model
+ fields = filter_fields
+ return AutoFilterSet
+
+ return None
+
+ def filter_queryset(self, request, queryset, view):
+ filter_class = self.get_filter_class(view)
+
+ if filter_class:
+ return filter_class(request.GET, queryset=queryset)
+
+ return queryset
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 59739d01..ebd06e45 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -1,5 +1,5 @@
"""
-Generic views that provide commmonly needed behaviour.
+Generic views that provide commonly needed behaviour.
"""
from rest_framework import views, mixins
@@ -10,12 +10,12 @@ from django.views.generic.list import MultipleObjectMixin
### Base classes for the generic views ###
-class BaseView(views.APIView):
+class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
serializer_class = None
- model_serializer_class = api_settings.MODEL_SERIALIZER
+ model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
def get_serializer_context(self):
"""
@@ -43,21 +43,31 @@ class BaseView(views.APIView):
return serializer_class
- def get_serializer(self, data=None, files=None, instance=None):
+ def get_serializer(self, instance=None, data=None, files=None):
# TODO: add support for files
# TODO: add support for seperate serializer/deserializer
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(data, instance=instance, context=context)
+ return serializer_class(instance, data=data, context=context)
-class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
+class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
"""
Base class for generic views onto a queryset.
"""
- pagination_serializer_class = api_settings.PAGINATION_SERIALIZER
+ pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY
+ filter_backend = api_settings.FILTER_BACKEND
+
+ def filter_queryset(self, queryset):
+ if not self.filter_backend:
+ return queryset
+ backend = self.filter_backend()
+ return backend.filter_queryset(self.request, queryset, self)
+
+ def get_filtered_queryset(self):
+ return self.filter_queryset(self.get_queryset())
def get_pagination_serializer_class(self):
"""
@@ -75,7 +85,7 @@ class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
return pagination_serializer_class(instance=page, context=context)
-class SingleObjectBaseView(SingleObjectMixin, BaseView):
+class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
"""
Base class for generic views onto a model instance.
"""
@@ -86,7 +96,7 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView):
"""
Override default to add support for object-level permissions.
"""
- obj = super(SingleObjectBaseView, self).get_object()
+ obj = super(SingleObjectAPIView, self).get_object()
if not self.has_permission(self.request, obj):
self.permission_denied(self.request)
return obj
@@ -95,8 +105,19 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView):
### Concrete view classes that provide method handlers ###
### by composing the mixin classes with a base view. ###
+
+class CreateAPIView(mixins.CreateModelMixin,
+ GenericAPIView):
+
+ """
+ Concrete view for creating a model instance.
+ """
+ def post(self, request, *args, **kwargs):
+ return self.create(request, *args, **kwargs)
+
+
class ListAPIView(mixins.ListModelMixin,
- MultipleObjectBaseView):
+ MultipleObjectAPIView):
"""
Concrete view for listing a queryset.
"""
@@ -104,9 +125,38 @@ class ListAPIView(mixins.ListModelMixin,
return self.list(request, *args, **kwargs)
+class RetrieveAPIView(mixins.RetrieveModelMixin,
+ SingleObjectAPIView):
+ """
+ Concrete view for retrieving a model instance.
+ """
+ def get(self, request, *args, **kwargs):
+ return self.retrieve(request, *args, **kwargs)
+
+
+class DestroyAPIView(mixins.DestroyModelMixin,
+ SingleObjectAPIView):
+
+ """
+ Concrete view for deleting a model instance.
+ """
+ def delete(self, request, *args, **kwargs):
+ return self.destroy(request, *args, **kwargs)
+
+
+class UpdateAPIView(mixins.UpdateModelMixin,
+ SingleObjectAPIView):
+
+ """
+ Concrete view for updating a model instance.
+ """
+ def put(self, request, *args, **kwargs):
+ return self.update(request, *args, **kwargs)
+
+
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
- MultipleObjectBaseView):
+ MultipleObjectAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
@@ -117,18 +167,9 @@ class ListCreateAPIView(mixins.ListModelMixin,
return self.create(request, *args, **kwargs)
-class RetrieveAPIView(mixins.RetrieveModelMixin,
- SingleObjectBaseView):
- """
- Concrete view for retrieving a model instance.
- """
- def get(self, request, *args, **kwargs):
- return self.retrieve(request, *args, **kwargs)
-
-
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
- SingleObjectBaseView):
+ SingleObjectAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
@@ -142,7 +183,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
- SingleObjectBaseView):
+ SingleObjectAPIView):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 29153e18..c3625a88 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -3,9 +3,6 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
-
-Eg. Use mixins to build a Resource class, and have a Router class
- perform the binding of http methods to actions for us.
"""
from django.http import Http404
from rest_framework import status
@@ -20,20 +17,24 @@ class CreateModelMixin(object):
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA)
if serializer.is_valid():
+ self.pre_save(serializer.object)
self.object = serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ def pre_save(self, obj):
+ pass
+
class ListModelMixin(object):
"""
List a queryset.
- Should be mixed in with `MultipleObjectBaseView`.
+ Should be mixed in with `MultipleObjectAPIView`.
"""
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
- self.object_list = self.get_queryset()
+ self.object_list = self.get_filtered_queryset()
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
@@ -46,10 +47,11 @@ class ListModelMixin(object):
# which may be `None` to disable pagination.
page_size = self.get_paginate_by(self.object_list)
if page_size:
- paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size)
+ packed = self.paginate_queryset(self.object_list, page_size)
+ paginator, page, queryset, is_paginated = packed
serializer = self.get_pagination_serializer(page)
else:
- serializer = self.get_serializer(instance=self.object_list)
+ serializer = self.get_serializer(self.object_list)
return Response(serializer.data)
@@ -61,7 +63,7 @@ class RetrieveModelMixin(object):
"""
def retrieve(self, request, *args, **kwargs):
self.object = self.get_object()
- serializer = self.get_serializer(instance=self.object)
+ serializer = self.get_serializer(self.object)
return Response(serializer.data)
@@ -73,26 +75,25 @@ class UpdateModelMixin(object):
def update(self, request, *args, **kwargs):
try:
self.object = self.get_object()
+ success_status = status.HTTP_200_OK
except Http404:
self.object = None
+ success_status = status.HTTP_201_CREATED
- serializer = self.get_serializer(data=request.DATA, instance=self.object)
+ serializer = self.get_serializer(self.object, data=request.DATA)
if serializer.is_valid():
- if self.object is None:
- # If PUT occurs to a non existant object, we need to set any
- # attributes on the object that are implicit in the URL.
- self.update_urlconf_attributes(serializer.object)
+ self.pre_save(serializer.object)
self.object = serializer.save()
- return Response(serializer.data)
+ return Response(serializer.data, status=success_status)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- def update_urlconf_attributes(self, obj):
+ def pre_save(self, obj):
"""
- When update (re)creates an object, we need to set any attributes that
- are tied to the URLconf.
+ Set any attributes on the object that are implicit in the request.
"""
+ # pk and/or slug attributes are implicit in the URL.
pk = self.kwargs.get(self.pk_url_kwarg, None)
if pk:
setattr(obj, 'pk', pk)
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index 8b22f669..dae38477 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -1,48 +1,38 @@
+from django.http import Http404
from rest_framework import exceptions
from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import order_by_precedence, media_type_matches
class BaseContentNegotiation(object):
- def negotiate(self, request, renderers, format=None, force=False):
- raise NotImplementedError('.negotiate() must be implemented')
+ def select_parser(self, request, parsers):
+ raise NotImplementedError('.select_parser() must be implemented')
+ def select_renderer(self, request, renderers, format_suffix=None):
+ raise NotImplementedError('.select_renderer() must be implemented')
-class DefaultContentNegotiation(object):
+
+class DefaultContentNegotiation(BaseContentNegotiation):
settings = api_settings
- def select_parser(self, parsers, media_type):
+ def select_parser(self, request, parsers):
"""
Given a list of parsers and a media type, return the appropriate
parser to handle the incoming request.
"""
for parser in parsers:
- if media_type_matches(parser.media_type, media_type):
+ if media_type_matches(parser.media_type, request.content_type):
return parser
return None
- def negotiate(self, request, renderers, format=None, force=False):
+ def select_renderer(self, request, renderers, format_suffix=None):
"""
Given a request and a list of renderers, return a two-tuple of:
(renderer, media type).
-
- If force is set, then suppress exceptions, and forcibly return a
- fallback renderer and media_type.
- """
- try:
- return self.unforced_negotiate(request, renderers, format)
- except (exceptions.InvalidFormat, exceptions.NotAcceptable):
- if force:
- return (renderers[0], renderers[0].media_type)
- raise
-
- def unforced_negotiate(self, request, renderers, format=None):
- """
- As `.negotiate()`, but does not take the optional `force` agument,
- or suppress exceptions.
"""
# Allow URL style format override. eg. "?format=json
- format = format or request.GET.get(self.settings.URL_FORMAT_OVERRIDE)
+ format_query_param = self.settings.URL_FORMAT_OVERRIDE
+ format = format_suffix or request.GET.get(format_query_param)
if format:
renderers = self.filter_renderers(renderers, format)
@@ -77,7 +67,7 @@ class DefaultContentNegotiation(object):
renderers = [renderer for renderer in renderers
if renderer.format == format]
if not renderers:
- raise exceptions.InvalidFormat(format)
+ raise Http404
return renderers
def get_accept_list(self, request):
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 131718fd..d241ade7 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,4 +1,5 @@
from rest_framework import serializers
+from rest_framework.templatetags.rest_framework import replace_query_param
# TODO: Support URLconf kwarg-style paging
@@ -7,30 +8,30 @@ class NextPageField(serializers.Field):
"""
Field that returns a link to the next page in paginated results.
"""
+ page_field = 'page'
+
def to_native(self, value):
if not value.has_next():
return None
page = value.next_page_number()
request = self.context.get('request')
- relative_url = '?page=%d' % page
- if request:
- return request.build_absolute_uri(relative_url)
- return relative_url
+ url = request and request.build_absolute_uri() or ''
+ return replace_query_param(url, self.page_field, page)
class PreviousPageField(serializers.Field):
"""
Field that returns a link to the previous page in paginated results.
"""
+ page_field = 'page'
+
def to_native(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
request = self.context.get('request')
- relative_url = '?page=%d' % page
- if request:
- return request.build_absolute_uri('?page=%d' % page)
- return relative_url
+ url = request and request.build_absolute_uri() or ''
+ return replace_query_param(url, self.page_field, page)
class PaginationSerializerOptions(serializers.SerializerOptions):
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 048b71e1..4841676c 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -1,14 +1,8 @@
"""
-Django supports parsing the content of an HTTP request, but only for form POST requests.
-That behavior is sufficient for dealing with standard HTML forms, but it doesn't map well
-to general HTTP requests.
+Parsers are used to parse the content of incoming HTTP requests.
-We need a method to be able to:
-
-1.) Determine the parsed content on a request for methods other than POST (eg typically also PUT)
-
-2.) Determine the parsed content on a request for media types other than application/x-www-form-urlencoded
- and multipart/form-data. (eg also handle multipart/json)
+They give us a generic way of being able to handle various media types
+on the request, such as form content or json encoded data.
"""
from django.http import QueryDict
@@ -21,7 +15,6 @@ from xml.etree import ElementTree as ET
from xml.parsers.expat import ExpatError
import datetime
import decimal
-from io import BytesIO
class DataAndFiles(object):
@@ -33,29 +26,18 @@ class DataAndFiles(object):
class BaseParser(object):
"""
All parsers should extend `BaseParser`, specifying a `media_type`
- attribute, and overriding the `.parse_stream()` method.
+ attribute, and overriding the `.parse()` method.
"""
media_type = None
- def parse(self, string_or_stream, parser_context=None):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
- The main entry point to parsers. This is a light wrapper around
- `parse_stream`, that instead handles both string and stream objects.
- """
- if isinstance(string_or_stream, basestring):
- stream = BytesIO(string_or_stream)
- else:
- stream = string_or_stream
- return self.parse_stream(stream, parser_context)
-
- def parse_stream(self, stream, parser_context=None):
- """
- Given a stream to read from, return the deserialized output.
- Should return parsed data, or a DataAndFiles object consisting of the
+ Given a stream to read from, return the parsed representation.
+ Should return parsed data, or a `DataAndFiles` object consisting of the
parsed data and files.
"""
- raise NotImplementedError(".parse_stream() must be overridden.")
+ raise NotImplementedError(".parse() must be overridden.")
class JSONParser(BaseParser):
@@ -65,7 +47,7 @@ class JSONParser(BaseParser):
media_type = 'application/json'
- def parse_stream(self, stream, parser_context=None):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a 2-tuple of `(data, files)`.
@@ -85,7 +67,7 @@ class YAMLParser(BaseParser):
media_type = 'application/yaml'
- def parse_stream(self, stream, parser_context=None):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a 2-tuple of `(data, files)`.
@@ -105,7 +87,7 @@ class FormParser(BaseParser):
media_type = 'application/x-www-form-urlencoded'
- def parse_stream(self, stream, parser_context=None):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a 2-tuple of `(data, files)`.
@@ -123,7 +105,7 @@ class MultiPartParser(BaseParser):
media_type = 'multipart/form-data'
- def parse_stream(self, stream, parser_context=None):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a DataAndFiles object.
@@ -131,8 +113,10 @@ class MultiPartParser(BaseParser):
`.files` will be a `QueryDict` containing all the form files.
"""
parser_context = parser_context or {}
- meta = parser_context['meta']
- upload_handlers = parser_context['upload_handlers']
+ request = parser_context['request']
+ meta = request.META
+ upload_handlers = request.upload_handlers
+
try:
parser = DjangoMultiPartParser(meta, stream, upload_handlers)
data, files = parser.parse()
@@ -148,7 +132,7 @@ class XMLParser(BaseParser):
media_type = 'application/xml'
- def parse_stream(self, stream, parser_context=None):
+ def parse(self, stream, media_type=None, parser_context=None):
try:
tree = ET.parse(stream)
except (ExpatError, ETParseError, ValueError), exc:
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 6f848cee..655b78a3 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -18,6 +18,17 @@ class BasePermission(object):
raise NotImplementedError(".has_permission() must be overridden.")
+class AllowAny(BasePermission):
+ """
+ Allow any access.
+ This isn't strictly required, since you could use an empty
+ permission_classes list, but it's useful because it makes the intention
+ more explicit.
+ """
+ def has_permission(self, request, view, obj=None):
+ return True
+
+
class IsAuthenticated(BasePermission):
"""
Allows access only to authenticated users.
@@ -85,7 +96,7 @@ class DjangoModelPermissions(BasePermission):
"""
kwargs = {
'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.module_name
+ 'model_name': model_cls._meta.module_name
}
return [perm % kwargs for perm in self.perms_map[method]]
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 94d253c9..22fd6e74 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -1,14 +1,16 @@
"""
-Renderers are used to serialize a View's output into specific media types.
+Renderers are used to serialize a response into specific media types.
-Django REST framework also provides HTML and PlainText renderers that help self-document the API,
-by serializing the output along with documentation regarding the View, output status and headers,
-and providing forms and links depending on the allowed methods, renderers and parsers on the View.
+They give us a generic way of being able to handle various media types
+on the response, such as JSON encoded data or HTML output.
+
+REST framework also provides an HTML renderer the renders the browseable API.
"""
+import copy
import string
from django import forms
from django.http.multipartparser import parse_header
-from django.template import RequestContext, loader
+from django.template import RequestContext, loader, Template
from django.utils import simplejson as json
from rest_framework.compat import yaml
from rest_framework.exceptions import ConfigurationError
@@ -23,8 +25,8 @@ from rest_framework import serializers, parsers
class BaseRenderer(object):
"""
- All renderers must extend this class, set the :attr:`media_type` attribute,
- and override the :meth:`render` method.
+ All renderers should extend this class, setting the `media_type`
+ and `format` attributes, and override the `.render()` method.
"""
media_type = None
@@ -98,7 +100,7 @@ class JSONPRenderer(JSONRenderer):
callback = self.get_callback(renderer_context)
json = super(JSONPRenderer, self).render(data, accepted_media_type,
renderer_context)
- return "%s(%s);" % (callback, json)
+ return u"%s(%s);" % (callback, json)
class XMLRenderer(BaseRenderer):
@@ -137,18 +139,33 @@ class YAMLRenderer(BaseRenderer):
return yaml.dump(data, stream=None, Dumper=self.encoder)
-class HTMLRenderer(BaseRenderer):
+class TemplateHTMLRenderer(BaseRenderer):
"""
- A Base class provided for convenience.
+ An HTML renderer for use with templates.
+
+ The data supplied to the Response object should be a dictionary that will
+ be used as context for the template.
+
+ The template name is determined by (in order of preference):
+
+ 1. An explicit `.template_name` attribute set on the response.
+ 2. An explicit `.template_name` attribute set on this class.
+ 3. The return result of calling `view.get_template_names()`.
- Render the object simply by using the given template.
- To create a template renderer, subclass this class, and set
- the :attr:`media_type` and :attr:`template` attributes.
+ For example:
+ data = {'users': User.objects.all()}
+ return Response(data, template_name='users.html')
+
+ For pre-rendered HTML, see StaticHTMLRenderer.
"""
media_type = 'text/html'
format = 'html'
template_name = None
+ exception_template_names = [
+ '%(status_code)s.html',
+ 'api_exception.html'
+ ]
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -165,15 +182,21 @@ class HTMLRenderer(BaseRenderer):
request = renderer_context['request']
response = renderer_context['response']
- template_names = self.get_template_names(response, view)
- template = self.resolve_template(template_names)
- context = self.resolve_context(data, request)
+ if response.exception:
+ template = self.get_exception_template(response)
+ else:
+ template_names = self.get_template_names(response, view)
+ template = self.resolve_template(template_names)
+
+ context = self.resolve_context(data, request, response)
return template.render(context)
def resolve_template(self, template_names):
return loader.select_template(template_names)
- def resolve_context(self, data, request):
+ def resolve_context(self, data, request, response):
+ if response.exception:
+ data['status_code'] = response.status_code
return RequestContext(request, data)
def get_template_names(self, response, view):
@@ -185,6 +208,48 @@ class HTMLRenderer(BaseRenderer):
return view.get_template_names()
raise ConfigurationError('Returned a template response with no template_name')
+ def get_exception_template(self, response):
+ template_names = [name % {'status_code': response.status_code}
+ for name in self.exception_template_names]
+
+ try:
+ # Try to find an appropriate error template
+ return self.resolve_template(template_names)
+ except:
+ # Fall back to using eg '404 Not Found'
+ return Template('%d %s' % (response.status_code,
+ response.status_text.title()))
+
+
+# Note, subclass TemplateHTMLRenderer simply for the exception behavior
+class StaticHTMLRenderer(TemplateHTMLRenderer):
+ """
+ An HTML renderer class that simply returns pre-rendered HTML.
+
+ The data supplied to the Response object should be a string representing
+ the pre-rendered HTML content.
+
+ For example:
+ data = '<html><body>example</body></html>'
+ return Response(data)
+
+ For template rendered HTML, see TemplateHTMLRenderer.
+ """
+ media_type = 'text/html'
+ format = 'html'
+
+ def render(self, data, accepted_media_type=None, renderer_context=None):
+ renderer_context = renderer_context or {}
+ response = renderer_context['response']
+
+ if response and response.exception:
+ request = renderer_context['request']
+ template = self.get_exception_template(response)
+ context = self.resolve_context(data, request, response)
+ return template.render(context)
+
+ return data
+
class BrowsableAPIRenderer(BaseRenderer):
"""
@@ -222,11 +287,9 @@ class BrowsableAPIRenderer(BaseRenderer):
return content
- def get_form(self, view, method, request):
+ def show_form_for_method(self, view, method, request, obj):
"""
- Get a form, possibly bound to either the input or output data.
- In the absence on of the Resource having an associated form then
- provide a form that can be used to submit arbitrary content.
+ Returns True if a form should be shown for this method.
"""
if not method in view.allowed_methods:
return # Not a valid method
@@ -236,24 +299,13 @@ class BrowsableAPIRenderer(BaseRenderer):
request = clone_request(request, method)
try:
- if not view.has_permission(request):
+ if not view.has_permission(request, obj):
return # Don't have permission
except:
return # Don't have permission and exception explicitly raise
+ return True
- if method == 'DELETE' or method == 'OPTIONS':
- return True # Don't actually need to return a form
-
- if (not getattr(view, 'get_serializer', None) or
- not parsers.FormParser in getattr(view, 'parser_classes')):
- media_types = [parser.media_type for parser in view.parser_classes]
- return self.get_generic_content_form(media_types)
-
- #####
- # TODO: This is a little bit of a hack. Actually we'd like to remove
- # this and just render serializer fields to html directly.
-
- # We need to map our Fields to Django's Fields.
+ def serializer_to_form_fields(self, serializer):
field_mapping = {
serializers.FloatField: forms.FloatField,
serializers.IntegerField: forms.IntegerField,
@@ -261,34 +313,72 @@ class BrowsableAPIRenderer(BaseRenderer):
serializers.DateField: forms.DateField,
serializers.EmailField: forms.EmailField,
serializers.CharField: forms.CharField,
+ serializers.ChoiceField: forms.ChoiceField,
serializers.BooleanField: forms.BooleanField,
- serializers.PrimaryKeyRelatedField: forms.ModelChoiceField,
- serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField
+ serializers.PrimaryKeyRelatedField: forms.ChoiceField,
+ serializers.ManyPrimaryKeyRelatedField: forms.MultipleChoiceField,
+ serializers.SlugRelatedField: forms.ChoiceField,
+ serializers.ManySlugRelatedField: forms.MultipleChoiceField,
+ serializers.HyperlinkedRelatedField: forms.ChoiceField,
+ serializers.ManyHyperlinkedRelatedField: forms.MultipleChoiceField
}
- # Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
fields = {}
- obj, data = None, None
- if getattr(view, 'object', None):
- obj = view.object
-
- serializer = view.get_serializer(instance=obj)
for k, v in serializer.get_fields(True).items():
- if getattr(v, 'readonly', True):
+ if getattr(v, 'read_only', True):
continue
kwargs = {}
- if getattr(v, 'queryset', None):
- kwargs['queryset'] = getattr(v, 'queryset', None)
+ kwargs['required'] = v.required
+
+ #if getattr(v, 'queryset', None):
+ # kwargs['queryset'] = v.queryset
+
+ if getattr(v, 'choices', None) is not None:
+ kwargs['choices'] = v.choices
+
+ if getattr(v, 'widget', None):
+ widget = copy.deepcopy(v.widget)
+ kwargs['widget'] = widget
+
+ if getattr(v, 'default', None) is not None:
+ kwargs['initial'] = v.default
+
+ kwargs['label'] = k
try:
fields[k] = field_mapping[v.__class__](**kwargs)
except KeyError:
- fields[k] = forms.CharField()
+ if getattr(v, 'choices', None) is not None:
+ fields[k] = forms.ChoiceField(**kwargs)
+ else:
+ fields[k] = forms.CharField(**kwargs)
+ return fields
+
+ def get_form(self, view, method, request):
+ """
+ Get a form, possibly bound to either the input or output data.
+ In the absence on of the Resource having an associated form then
+ provide a form that can be used to submit arbitrary content.
+ """
+ obj = getattr(view, 'object', None)
+ if not self.show_form_for_method(view, method, request, obj):
+ return
+
+ if method == 'DELETE' or method == 'OPTIONS':
+ return True # Don't actually need to return a form
+
+ if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes:
+ media_types = [parser.media_type for parser in view.parser_classes]
+ return self.get_generic_content_form(media_types)
+
+ serializer = view.get_serializer(instance=obj)
+ fields = self.serializer_to_form_fields(serializer)
+ # Creating an on the fly form see:
+ # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
- if obj and not view.request.method == 'DELETE': # Don't fill in the form when the object is deleted
- data = serializer.data
+ data = (obj is not None) and serializer.data or None
form_instance = OnTheFlyForm(data)
return form_instance
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 6f9cf09a..a1827ba4 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -21,8 +21,8 @@ def is_form_media_type(media_type):
Return True if the media type is a valid form media type.
"""
base_media_type, params = parse_header(media_type)
- return base_media_type == 'application/x-www-form-urlencoded' or \
- base_media_type == 'multipart/form-data'
+ return (base_media_type == 'application/x-www-form-urlencoded' or
+ base_media_type == 'multipart/form-data')
class Empty(object):
@@ -88,16 +88,11 @@ class Request(object):
self._stream = Empty
if self.parser_context is None:
- self.parser_context = self._default_parser_context(request)
+ self.parser_context = {}
+ self.parser_context['request'] = self
def _default_negotiator(self):
- return api_settings.DEFAULT_CONTENT_NEGOTIATION()
-
- def _default_parser_context(self, request):
- return {
- 'upload_handlers': request.upload_handlers,
- 'meta': request.META,
- }
+ return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
@property
def method(self):
@@ -265,15 +260,19 @@ class Request(object):
May raise an `UnsupportedMediaType`, or `ParseError` exception.
"""
- if self.stream is None or self.content_type is None:
+ stream = self.stream
+ media_type = self.content_type
+
+ if stream is None or media_type is None:
return (None, None)
- parser = self.negotiator.select_parser(self.parsers, self.content_type)
+ parser = self.negotiator.select_parser(self, self.parsers)
if not parser:
- raise exceptions.UnsupportedMediaType(self.content_type)
+ raise exceptions.UnsupportedMediaType(media_type)
+
+ parsed = parser.parse(stream, media_type, self.parser_context)
- parsed = parser.parse(self.stream, self.parser_context)
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
try:
diff --git a/rest_framework/resources.py b/rest_framework/resources.py
deleted file mode 100644
index dd8a5471..00000000
--- a/rest_framework/resources.py
+++ /dev/null
@@ -1,96 +0,0 @@
-##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY #####
-
-from functools import update_wrapper
-import inspect
-from django.utils.decorators import classonlymethod
-from rest_framework import views, generics
-
-
-def wrapped(source, dest):
- """
- Copy public, non-method attributes from source to dest, and return dest.
- """
- for attr in [attr for attr in dir(source)
- if not attr.startswith('_') and not inspect.ismethod(attr)]:
- setattr(dest, attr, getattr(source, attr))
- return dest
-
-
-##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY #####
-
-class ResourceMixin(object):
- """
- Clone Django's `View.as_view()` behaviour *except* using REST framework's
- 'method -> action' binding for resources.
- """
-
- @classonlymethod
- def as_view(cls, actions, **initkwargs):
- """
- Main entry point for a request-response process.
- """
- # sanitize keyword arguments
- for key in initkwargs:
- if key in cls.http_method_names:
- raise TypeError("You tried to pass in the %s method name as a "
- "keyword argument to %s(). Don't do that."
- % (key, cls.__name__))
- if not hasattr(cls, key):
- raise TypeError("%s() received an invalid keyword %r" % (
- cls.__name__, key))
-
- def view(request, *args, **kwargs):
- self = cls(**initkwargs)
-
- # Bind methods to actions
- for method, action in actions.items():
- handler = getattr(self, action)
- setattr(self, method, handler)
-
- # As you were, solider.
- if hasattr(self, 'get') and not hasattr(self, 'head'):
- self.head = self.get
- return self.dispatch(request, *args, **kwargs)
-
- # take name and docstring from class
- update_wrapper(view, cls, updated=())
-
- # and possible attributes set by decorators
- # like csrf_exempt from dispatch
- update_wrapper(view, cls.dispatch, assigned=())
- return view
-
-
-##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY #####
-
-class Resource(ResourceMixin, views.APIView):
- pass
-
-
-##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY #####
-
-class ModelResource(ResourceMixin, views.APIView):
- # TODO: Actually delegation won't work
- root_class = generics.ListCreateAPIView
- detail_class = generics.RetrieveUpdateDestroyAPIView
-
- def root_view(self):
- return wrapped(self, self.root_class())
-
- def detail_view(self):
- return wrapped(self, self.detail_class())
-
- def list(self, request, *args, **kwargs):
- return self.root_view().list(request, args, kwargs)
-
- def create(self, request, *args, **kwargs):
- return self.root_view().create(request, args, kwargs)
-
- def retrieve(self, request, *args, **kwargs):
- return self.detail_view().retrieve(request, args, kwargs)
-
- def update(self, request, *args, **kwargs):
- return self.detail_view().update(request, args, kwargs)
-
- def destroy(self, request, *args, **kwargs):
- return self.detail_view().destroy(request, args, kwargs)
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 7a459c8f..0de01204 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -9,7 +9,8 @@ class Response(SimpleTemplateResponse):
"""
def __init__(self, data=None, status=200,
- template_name=None, headers=None):
+ template_name=None, headers=None,
+ exception=False):
"""
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
@@ -21,6 +22,7 @@ class Response(SimpleTemplateResponse):
self.data = data
self.headers = headers and headers[:] or []
self.template_name = template_name
+ self.exception = exception
@property
def rendered_content(self):
@@ -45,3 +47,13 @@ class Response(SimpleTemplateResponse):
# TODO: Deprecate and use a template tag instead
# TODO: Status code text for RFC 6585 status codes
return STATUS_CODE_TEXT.get(self.status_code, '')
+
+ def __getstate__(self):
+ """
+ Remove attributes from the response that shouldn't be cached
+ """
+ state = super(Response, self).__getstate__()
+ for key in ('accepted_renderer', 'renderer_context', 'data'):
+ if key in state:
+ del state[key]
+ return state
diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py
index ba663f98..c9db02f0 100644
--- a/rest_framework/reverse.py
+++ b/rest_framework/reverse.py
@@ -5,13 +5,15 @@ from django.core.urlresolvers import reverse as django_reverse
from django.utils.functional import lazy
-def reverse(viewname, *args, **kwargs):
+def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
Same as `django.core.urlresolvers.reverse`, but optionally takes a request
and returns a fully qualified URL, using the request to get the base URL.
"""
- request = kwargs.pop('request', None)
- url = django_reverse(viewname, *args, **kwargs)
+ if format is not None:
+ kwargs = kwargs or {}
+ kwargs['format'] = format
+ url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if request:
return request.build_absolute_uri(url)
return url
diff --git a/rest_framework/runtests/runcoverage.py b/rest_framework/runtests/runcoverage.py
index ea2e3d45..0ce379eb 100755
--- a/rest_framework/runtests/runcoverage.py
+++ b/rest_framework/runtests/runcoverage.py
@@ -32,10 +32,10 @@ def main():
'Function-based test runners are deprecated. Test runners should be classes with a run_tests() method.',
DeprecationWarning
)
- failures = TestRunner(['rest_framework'])
+ failures = TestRunner(['tests'])
else:
test_runner = TestRunner()
- failures = test_runner.run_tests(['rest_framework'])
+ failures = test_runner.run_tests(['tests'])
cov.stop()
# Discover the list of all modules that we should test coverage for
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
index b2438c9b..1bd0a5fc 100755
--- a/rest_framework/runtests/runtests.py
+++ b/rest_framework/runtests/runtests.py
@@ -32,7 +32,7 @@ def main():
else:
print usage()
sys.exit(1)
- failures = test_runner.run_tests(['rest_framework' + test_case])
+ failures = test_runner.run_tests(['tests' + test_case])
sys.exit(failures)
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 67de82c8..dd5d9dc3 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -21,6 +21,12 @@ DATABASES = {
}
}
+CACHES = {
+ 'default': {
+ 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
+ }
+}
+
# Local time zone for this installation. Choices can be found here:
# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name
# although not all choices may be available on all operating systems.
@@ -91,6 +97,7 @@ INSTALLED_APPS = (
# 'django.contrib.admindocs',
'rest_framework',
'rest_framework.authtoken',
+ 'rest_framework.tests'
)
STATIC_URL = '/static/'
@@ -100,13 +107,6 @@ import django
if django.VERSION < (1, 3):
INSTALLED_APPS += ('staticfiles',)
-# OAuth support is optional, so we only test oauth if it's installed.
-try:
- import oauth_provider
-except ImportError:
- pass
-else:
- INSTALLED_APPS += ('oauth_provider',)
# If we're running on the Jenkins server we want to archive the coverage reports as XML.
import os
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 13f8cde2..95145d58 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -3,8 +3,18 @@ import datetime
import types
from decimal import Decimal
from django.db import models
+from django.forms import widgets
from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model
+
+# Note: We do the following so that users of the framework can use this style:
+#
+# example_field = serializers.CharField(...)
+#
+# This helps keep the seperation between model fields, form fields, and
+# serializer fields more explicit.
+
+
from rest_framework.fields import *
@@ -22,10 +32,6 @@ class SortedDictWithMetadata(SortedDict, DictWithMetadata):
pass
-class RecursionOccured(BaseException):
- pass
-
-
def _is_protected_type(obj):
"""
True if the object is a native datatype that does not need to
@@ -33,10 +39,10 @@ def _is_protected_type(obj):
"""
return isinstance(obj, (
types.NoneType,
- int, long,
- datetime.datetime, datetime.date, datetime.time,
- float, Decimal,
- basestring)
+ int, long,
+ datetime.datetime, datetime.date, datetime.time,
+ float, Decimal,
+ basestring)
)
@@ -73,7 +79,7 @@ class SerializerOptions(object):
Meta class options for Serializer
"""
def __init__(self, meta):
- self.nested = getattr(meta, 'nested', False)
+ self.depth = getattr(meta, 'depth', 0)
self.fields = getattr(meta, 'fields', ())
self.exclude = getattr(meta, 'exclude', ())
@@ -85,14 +91,13 @@ class BaseSerializer(Field):
_options_class = SerializerOptions
_dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
- def __init__(self, data=None, instance=None, context=None, **kwargs):
+ def __init__(self, instance=None, data=None, context=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
- self.fields = copy.deepcopy(self.base_fields)
self.opts = self._options_class(self.Meta)
+ self.fields = copy.deepcopy(self.base_fields)
self.parent = None
self.root = None
- self.stack = []
self.context = context or {}
self.init_data = data
@@ -104,13 +109,13 @@ class BaseSerializer(Field):
#####
# Methods to determine which fields to use when (de)serializing objects.
- def default_fields(self, serialize, obj=None, data=None, nested=False):
+ def default_fields(self, nested=False):
"""
Return the complete set of default fields for the object, as a dict.
"""
return {}
- def get_fields(self, serialize, obj=None, data=None, nested=False):
+ def get_fields(self, nested=False):
"""
Returns the complete set of fields for the object as a dict.
@@ -123,10 +128,10 @@ class BaseSerializer(Field):
for key, field in self.fields.items():
ret[key] = field
# Set up the field
- field.initialize(parent=self)
+ field.initialize(parent=self, field_name=key)
# Add in the default fields
- fields = self.default_fields(serialize, obj, data, nested)
+ fields = self.default_fields(nested)
for key, val in fields.items():
if key not in ret:
ret[key] = val
@@ -148,17 +153,14 @@ class BaseSerializer(Field):
#####
# Field methods - used when the serializer class is itself used as a field.
- def initialize(self, parent):
+ def initialize(self, parent, field_name):
"""
Same behaviour as usual Field, except that we need to keep track
- of state so that we can deal with handling maximum depth and recursion.
+ of state so that we can deal with handling maximum depth.
"""
- super(BaseSerializer, self).initialize(parent)
- self.stack = parent.stack[:]
- if parent.opts.nested and not isinstance(parent.opts.nested, bool):
- self.opts.nested = parent.opts.nested - 1
- else:
- self.opts.nested = parent.opts.nested
+ super(BaseSerializer, self).initialize(parent, field_name)
+ if parent.opts.depth:
+ self.opts.depth = parent.opts.depth - 1
#####
# Methods to convert or revert from objects <--> primative representations.
@@ -174,21 +176,13 @@ class BaseSerializer(Field):
Core of serialization.
Convert an object into a dictionary of serialized field values.
"""
- if obj in self.stack and not self.source == '*':
- raise RecursionOccured()
- self.stack.append(obj)
-
ret = self._dict_class()
ret.fields = {}
- fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested)
+ fields = self.get_fields(nested=bool(self.opts.depth))
for field_name, field in fields.items():
key = self.get_field_key(field_name)
- try:
- value = field.field_to_native(obj, field_name)
- except RecursionOccured:
- field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name]
- value = field.field_to_native(obj, field_name)
+ value = field.field_to_native(obj, field_name)
ret[key] = value
ret.fields[key] = field
return ret
@@ -198,7 +192,7 @@ class BaseSerializer(Field):
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
- fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested)
+ fields = self.get_fields(nested=bool(self.opts.depth))
reverted_data = {}
for field_name, field in fields.items():
try:
@@ -208,6 +202,35 @@ class BaseSerializer(Field):
return reverted_data
+ def perform_validation(self, attrs):
+ """
+ Run `validate_<fieldname>()` and `validate()` methods on the serializer
+ """
+ # TODO: refactor this so we're not determining the fields again
+ fields = self.get_fields(nested=bool(self.opts.depth))
+
+ for field_name, field in fields.items():
+ try:
+ validate_method = getattr(self, 'validate_%s' % field_name, None)
+ if validate_method:
+ source = field.source or field_name
+ attrs = validate_method(attrs, source)
+ except ValidationError as err:
+ self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
+
+ try:
+ attrs = self.validate(attrs)
+ except ValidationError as err:
+ self._errors['non_field_errors'] = err.messages
+
+ return attrs
+
+ def validate(self, attrs):
+ """
+ Stub method, to be overridden in Serializer subclasses
+ """
+ return attrs
+
def restore_object(self, attrs, instance=None):
"""
Deserialize a dictionary of attributes into an object instance.
@@ -223,11 +246,8 @@ class BaseSerializer(Field):
"""
Serialize objects -> primatives.
"""
- if isinstance(obj, dict):
- return dict([(key, self.to_native(val))
- for (key, val) in obj.items()])
- elif hasattr(obj, '__iter__'):
- return [self.to_native(item) for item in obj]
+ if hasattr(obj, '__iter__'):
+ return [self.convert_object(item) for item in obj]
return self.convert_object(obj)
def from_native(self, data):
@@ -241,17 +261,31 @@ class BaseSerializer(Field):
self._errors = {}
if data is not None:
attrs = self.restore_fields(data)
+ attrs = self.perform_validation(attrs)
else:
- self._errors['non_field_errors'] = 'No input provided'
+ self._errors['non_field_errors'] = ['No input provided']
if not self._errors:
return self.restore_object(attrs, instance=getattr(self, 'object', None))
+ def field_to_native(self, obj, field_name):
+ """
+ Override default so that we can apply ModelSerializer as a nested
+ field to relationships.
+ """
+ obj = getattr(obj, self.source or field_name)
+
+ # If the object has an "all" method, assume it's a relationship
+ if is_simple_callable(getattr(obj, 'all', None)):
+ return [self.to_native(item) for item in obj.all()]
+
+ return self.to_native(obj)
+
@property
def errors(self):
"""
Run deserialization and return error data,
- setting self.object if no errors occured.
+ setting self.object if no errors occurred.
"""
if self._errors is None:
obj = self.from_native(self.init_data)
@@ -295,17 +329,7 @@ class ModelSerializer(Serializer):
"""
_options_class = ModelSerializerOptions
- def field_to_native(self, obj, field_name):
- """
- Override default so that we can apply ModelSerializer as a nested
- field to relationships.
- """
- obj = getattr(obj, self.source or field_name)
- if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
- return [self.to_native(item) for item in obj.all()]
- return self.to_native(obj)
-
- def default_fields(self, serialize, obj=None, data=None, nested=False):
+ def default_fields(self, nested=False):
"""
Return all the fields that should be serialized for the model.
"""
@@ -342,7 +366,7 @@ class ModelSerializer(Serializer):
field = self.get_field(model_field)
if field:
- field.initialize(parent=self)
+ field.initialize(parent=self, field_name=model_field.name)
ret[model_field.name] = field
return ret
@@ -374,6 +398,25 @@ class ModelSerializer(Serializer):
"""
Creates a default instance of a basic non-relational field.
"""
+ kwargs = {}
+
+ kwargs['blank'] = model_field.blank
+
+ if model_field.null:
+ kwargs['required'] = False
+
+ if model_field.has_default():
+ kwargs['required'] = False
+ kwargs['default'] = model_field.get_default()
+
+ if model_field.__class__ == models.TextField:
+ kwargs['widget'] = widgets.Textarea
+
+ # TODO: TypedChoiceField?
+ if model_field.flatchoices: # This ModelField contains choices
+ kwargs['choices'] = model_field.flatchoices
+ return ChoiceField(**kwargs)
+
field_mapping = {
models.FloatField: FloatField,
models.IntegerField: IntegerField,
@@ -389,14 +432,9 @@ class ModelSerializer(Serializer):
models.BooleanField: BooleanField,
}
try:
- ret = field_mapping[model_field.__class__]()
+ return field_mapping[model_field.__class__](**kwargs)
except KeyError:
- ret = ModelField(model_field=model_field)
-
- if model_field.default:
- ret.required = False
-
- return ret
+ return ModelField(model_field=model_field, **kwargs)
def restore_object(self, attrs, instance=None):
"""
@@ -409,6 +447,13 @@ class ModelSerializer(Serializer):
setattr(instance, key, val)
return instance
+ # Reverse relations
+ for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
+ field_name = obj.field.related_query_name()
+ if field_name in attrs:
+ self.m2m_data[field_name] = attrs.pop(field_name)
+
+ # Forward relations
for field in self.opts.model._meta.many_to_many:
if field.name in attrs:
self.m2m_data[field.name] = attrs.pop(field.name)
@@ -420,7 +465,7 @@ class ModelSerializer(Serializer):
"""
self.object.save()
- if self.m2m_data and save_m2m:
+ if getattr(self, 'm2m_data', None) and save_m2m:
for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list)
self.m2m_data = {}
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 8bbb2f75..906a7cf6 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -3,11 +3,11 @@ Settings for REST framework are all namespaced in the REST_FRAMEWORK setting.
For example your project's `settings.py` file might look like this:
REST_FRAMEWORK = {
- 'DEFAULT_RENDERERS': (
+ 'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.YAMLRenderer',
)
- 'DEFAULT_PARSERS': (
+ 'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
'rest_framework.parsers.YAMLParser',
)
@@ -24,31 +24,38 @@ from django.utils import importlib
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
DEFAULTS = {
- 'DEFAULT_RENDERERS': (
+ 'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.BrowsableAPIRenderer',
),
- 'DEFAULT_PARSERS': (
+ 'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
'rest_framework.parsers.FormParser',
'rest_framework.parsers.MultiPartParser'
),
- 'DEFAULT_AUTHENTICATION': (
+ 'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.SessionAuthentication',
'rest_framework.authentication.BasicAuthentication'
),
- 'DEFAULT_PERMISSIONS': (),
- 'DEFAULT_THROTTLES': (),
- 'DEFAULT_CONTENT_NEGOTIATION':
+ 'DEFAULT_PERMISSION_CLASSES': (
+ 'rest_framework.permissions.AllowAny',
+ ),
+ 'DEFAULT_THROTTLE_CLASSES': (
+ ),
+
+ 'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
+ 'DEFAULT_MODEL_SERIALIZER_CLASS':
+ 'rest_framework.serializers.ModelSerializer',
+ 'DEFAULT_PAGINATION_SERIALIZER_CLASS':
+ 'rest_framework.pagination.PaginationSerializer',
+
'DEFAULT_THROTTLE_RATES': {
'user': None,
'anon': None,
},
-
- 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
- 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
'PAGINATE_BY': None,
+ 'FILTER_BACKEND': None,
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@@ -65,14 +72,15 @@ DEFAULTS = {
# List of settings that may be in string import notation.
IMPORT_STRINGS = (
- 'DEFAULT_RENDERERS',
- 'DEFAULT_PARSERS',
- 'DEFAULT_AUTHENTICATION',
- 'DEFAULT_PERMISSIONS',
- 'DEFAULT_THROTTLES',
- 'DEFAULT_CONTENT_NEGOTIATION',
- 'MODEL_SERIALIZER',
- 'PAGINATION_SERIALIZER',
+ 'DEFAULT_RENDERER_CLASSES',
+ 'DEFAULT_PARSER_CLASSES',
+ 'DEFAULT_AUTHENTICATION_CLASSES',
+ 'DEFAULT_PERMISSION_CLASSES',
+ 'DEFAULT_THROTTLE_CLASSES',
+ 'DEFAULT_CONTENT_NEGOTIATION_CLASS',
+ 'DEFAULT_MODEL_SERIALIZER_CLASS',
+ 'DEFAULT_PAGINATION_SERIALIZER_CLASS',
+ 'FILTER_BACKEND',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
@@ -111,7 +119,7 @@ class APISettings(object):
For example:
from rest_framework.settings import api_settings
- print api_settings.DEFAULT_RENDERERS
+ print api_settings.DEFAULT_RENDERER_CLASSES
Any setting with string import paths will be automatically resolved
and return the class, rather than the string literal.
@@ -136,8 +144,15 @@ class APISettings(object):
if val and attr in self.import_strings:
val = perform_import(val, attr)
+ self.validate_setting(attr, val)
+
# Cache the result
setattr(self, attr, val)
return val
+ def validate_setting(self, attr, val):
+ if attr == 'FILTER_BACKEND' and val is not None:
+ # Make sure we can initilize the class
+ val()
+
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css
index 739b9300..b2e41b99 100644
--- a/rest_framework/static/rest_framework/css/default.css
+++ b/rest_framework/static/rest_framework/css/default.css
@@ -32,6 +32,17 @@ h2, h3 {
margin-right: 1em;
}
+ul.breadcrumb {
+ margin: 58px 0 0 0;
+}
+
+form select, form input, form textarea {
+ width: 90%;
+}
+
+form select[multiple] {
+ height: 150px;
+}
/* To allow tooltips to work on disabled elements */
.disabled-tooltip-shield {
position: absolute;
@@ -55,6 +66,7 @@ pre {
.page-header {
border-bottom: none;
padding-bottom: 0px;
+ margin-bottom: 20px;
}
@@ -65,7 +77,7 @@ html{
background: none;
}
-body, .navbar .navbar-inner .container-fluid{
+body, .navbar .navbar-inner .container-fluid {
max-width: 1150px;
margin: 0 auto;
}
@@ -76,13 +88,14 @@ body{
}
#content{
- margin: 40px 0 0 0;
+ margin: 0;
}
/* custom navigation styles */
.wrapper .navbar{
- width:100%;
+ width: 100%;
position: absolute;
- left:0;
+ left: 0;
+ top: 0;
}
.navbar .navbar-inner{
diff --git a/rest_framework/status.py b/rest_framework/status.py
index f3a5e481..a1eb48da 100644
--- a/rest_framework/status.py
+++ b/rest_framework/status.py
@@ -49,4 +49,4 @@ HTTP_502_BAD_GATEWAY = 502
HTTP_503_SERVICE_UNAVAILABLE = 503
HTTP_504_GATEWAY_TIMEOUT = 504
HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505
-HTTP_511_NETWORD_AUTHENTICATION_REQUIRED = 511
+HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 5ac6ef67..fb0e19f0 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -109,7 +109,7 @@
<div class="content-main">
<div class="page-header"><h1>{{ name }}</h1></div>
- <p class="resource-description">{{ description }}</p>
+ {{ description }}
<div class="request-info">
<pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>
@@ -131,12 +131,12 @@
{% csrf_token %}
{{ post_form.non_field_errors }}
{% for field in post_form %}
- <div class="control-group {% if field.errors %}error{% endif %}">
+ <div class="control-group"> <!--{% if field.errors %}error{% endif %}-->
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
- {{ field|add_class:"input-xlarge" }}
+ {{ field }}
<span class="help-inline">{{ field.help_text }}</span>
- {{ field.errors|add_class:"help-block" }}
+ <!--{{ field.errors|add_class:"help-block" }}-->
</div>
</div>
{% endfor %}
@@ -156,12 +156,12 @@
{% csrf_token %}
{{ put_form.non_field_errors }}
{% for field in put_form %}
- <div class="control-group {% if field.errors %}error{% endif %}">
+ <div class="control-group"> <!--{% if field.errors %}error{% endif %}-->
{{ field.label_tag|add_class:"control-label" }}
<div class="controls">
- {{ field|add_class:"input-xlarge" }}
+ {{ field }}
<span class='help-inline'>{{ field.help_text }}</span>
- {{ field.errors|add_class:"help-block" }}
+ <!--{{ field.errors|add_class:"help-block" }}-->
</div>
</div>
{% endfor %}
diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html
index 65af512e..c1271399 100644
--- a/rest_framework/templates/rest_framework/login.html
+++ b/rest_framework/templates/rest_framework/login.html
@@ -3,42 +3,50 @@
<html>
<head>
- <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/style.css'/>
+ <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/>
+ <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/>
+ <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/>
</head>
- <body class="login">
+ <body class="container">
- <div id="container">
-
- <div id="header">
- <div id="branding">
- <h1 id="site-name">Django REST framework</h1>
+<div class="container-fluid" style="margin-top: 30px">
+ <div class="row-fluid">
+
+ <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
+ <div class="row-fluid">
+ <div>
+ <h3 style="margin: 0 0 20px;">Django REST framework</h3>
</div>
- </div>
+ </div><!-- /row fluid -->
- <div id="content" class="colM">
- <div id="content-main">
- <form method="post" action="{% url 'rest_framework:login' %}" id="login-form">
+ <div class="row-fluid">
+ <div>
+ <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
{% csrf_token %}
- <div class="form-row">
- <label for="id_username">Username:</label> {{ form.username }}
+ <div id="div_id_username" class="clearfix control-group">
+ <div class="controls" style="height: 30px">
+ <Label class="span4" style="margin-top: 3px">Username:</label>
+ <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
+ </div>
</div>
- <div class="form-row">
- <label for="id_password">Password:</label> {{ form.password }}
- <input type="hidden" name="next" value="{{ next }}" />
+ <div id="div_id_password" class="clearfix control-group">
+ <div class="controls" style="height: 30px">
+ <Label class="span4" style="margin-top: 3px">Password:</label>
+ <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
+ </div>
</div>
- <div class="form-row">
- <label>&nbsp;</label><input type="submit" value="Log in">
+ <input type="hidden" name="next" value="{{ next }}" />
+ <div class="form-actions-no-box">
+ <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
</div>
</form>
- <script type="text/javascript">
- document.getElementById('id_username').focus()
- </script>
</div>
- <br class="clear">
- </div>
+ </div><!-- /row fluid -->
+ </div><!--/span-->
- <div id="footer"></div>
+ </div><!-- /.row-fluid -->
+ </div>
</div>
</body>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index c9b6eb10..4e0181ee 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -11,6 +11,18 @@ import string
register = template.Library()
+def replace_query_param(url, key, val):
+ """
+ Given a URL and a key/val pair, set or replace an item in the query
+ parameters of the URL, and return the new URL.
+ """
+ (scheme, netloc, path, query, fragment) = urlsplit(url)
+ query_dict = QueryDict(query).copy()
+ query_dict[key] = val
+ query = query_dict.urlencode()
+ return urlunsplit((scheme, netloc, path, query, fragment))
+
+
# Regex for adding classes to html snippets
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
@@ -31,19 +43,6 @@ hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|
trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
-# Helper function for 'add_query_param'
-def replace_query_param(url, key, val):
- """
- Given a URL and a key/val pair, set or replace an item in the query
- parameters of the URL, and return the new URL.
- """
- (scheme, netloc, path, query, fragment) = urlsplit(url)
- query_dict = QueryDict(query).copy()
- query_dict[key] = val
- query = query_dict.urlencode()
- return urlunsplit((scheme, netloc, path, query, fragment))
-
-
# And the template tags themselves...
@register.simple_tag
diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py
index adeaf6da..e69de29b 100644
--- a/rest_framework/tests/__init__.py
+++ b/rest_framework/tests/__init__.py
@@ -1,13 +0,0 @@
-"""
-Force import of all modules in this package in order to get the standard test
-runner to pick up the tests. Yowzers.
-"""
-import os
-
-modules = [filename.rsplit('.', 1)[0]
- for filename in os.listdir(os.path.dirname(__file__))
- if filename.endswith('.py') and not filename.startswith('_')]
-__test__ = dict()
-
-for module in modules:
- exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
new file mode 100644
index 00000000..af2e6c2e
--- /dev/null
+++ b/rest_framework/tests/filterset.py
@@ -0,0 +1,168 @@
+import datetime
+from decimal import Decimal
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import unittest
+from rest_framework import generics, status, filters
+from rest_framework.compat import django_filters
+from rest_framework.tests.models import FilterableItem, BasicModel
+
+factory = RequestFactory()
+
+
+if django_filters:
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_fields = ['decimal', 'date']
+ filter_backend = filters.DjangoFilterBackend
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backend = filters.DjangoFilterBackend
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ filter_class = MisconfiguredFilter
+ filter_backend = filters.DjangoFilterBackend
+
+
+class IntegrationTestFiltering(TestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] == search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] == search_date]
+ self.assertEquals(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/?decimal=%s' % search_decimal)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['decimal'] < search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/?text=%s' % search_text)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEquals(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if f['date'] > search_date and
+ f['decimal'] < search_decimal]
+ self.assertEquals(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/?integer=%s' % search_integer)
+ response = view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
index 1d7e33bc..bc7378e1 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/genericrelations.py
@@ -25,7 +25,7 @@ class TestGenericRelations(TestCase):
model = Bookmark
exclude = ('id',)
- serializer = BookmarkSerializer(instance=self.bookmark)
+ serializer = BookmarkSerializer(self.bookmark)
expected = {
'tags': [u'django', u'python'],
'url': u'https://www.djangoproject.com/'
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f4263478..a8279ef2 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -2,7 +2,7 @@ from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import simplejson as json
from rest_framework import generics, serializers, status
-from rest_framework.tests.models import BasicModel, Comment
+from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
factory = RequestFactory()
@@ -22,6 +22,22 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView):
model = BasicModel
+class SlugSerializer(serializers.ModelSerializer):
+ slug = serializers.Field() # read only
+
+ class Meta:
+ model = SlugBasedModel
+ exclude = ('id',)
+
+
+class SlugBasedInstanceView(InstanceView):
+ """
+ A model with a slug-field.
+ """
+ model = SlugBasedModel
+ serializer_class = SlugSerializer
+
+
class TestRootView(TestCase):
def setUp(self):
"""
@@ -129,6 +145,7 @@ class TestInstanceView(TestCase):
for obj in self.objects.all()
]
self.view = InstanceView.as_view()
+ self.slug_based_view = SlugBasedInstanceView.as_view()
def test_get_instance_view(self):
"""
@@ -198,7 +215,7 @@ class TestInstanceView(TestCase):
def test_put_cannot_set_id(self):
"""
- POST requests to create a new object should not be able to set the id.
+ PUT requests to create a new object should not be able to set the id.
"""
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
@@ -219,11 +236,39 @@ class TestInstanceView(TestCase):
request = factory.put('/1', json.dumps(content),
content_type='application/json')
response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
self.assertEquals(updated.text, 'foobar')
+ def test_put_as_create_on_id_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if it doesn't exist.
+ """
+ content = {'text': 'foobar'}
+ # pk fields can not be created on demand, only the database can set th pk for a new object
+ request = factory.put('/5', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=5).render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ new_obj = self.objects.get(pk=5)
+ self.assertEquals(new_obj.text, 'foobar')
+
+ def test_put_as_create_on_slug_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
+ """
+ content = {'text': 'foobar'}
+ request = factory.put('/test_slug', json.dumps(content),
+ content_type='application/json')
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ new_obj = SlugBasedModel.objects.get(slug='test_slug')
+ self.assertEquals(new_obj.text, 'foobar')
+
# Regression test for #285
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py
index da2f83c3..4caed59e 100644
--- a/rest_framework/tests/htmlrenderer.py
+++ b/rest_framework/tests/htmlrenderer.py
@@ -1,14 +1,16 @@
+from django.core.exceptions import PermissionDenied
from django.conf.urls.defaults import patterns, url
+from django.http import Http404
from django.test import TestCase
from django.template import TemplateDoesNotExist, Template
import django.template.loader
from rest_framework.decorators import api_view, renderer_classes
-from rest_framework.renderers import HTMLRenderer
+from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
@api_view(('GET',))
-@renderer_classes((HTMLRenderer,))
+@renderer_classes((TemplateHTMLRenderer,))
def example(request):
"""
A view that can returns an HTML representation.
@@ -17,12 +19,26 @@ def example(request):
return Response(data, template_name='example.html')
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def permission_denied(request):
+ raise PermissionDenied()
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def not_found(request):
+ raise Http404()
+
+
urlpatterns = patterns('',
url(r'^$', example),
+ url(r'^permission_denied$', permission_denied),
+ url(r'^not_found$', not_found),
)
-class HTMLRendererTests(TestCase):
+class TemplateHTMLRendererTests(TestCase):
urls = 'rest_framework.tests.htmlrenderer'
def setUp(self):
@@ -48,3 +64,52 @@ class HTMLRendererTests(TestCase):
response = self.client.get('/')
self.assertContains(response, "example: foobar")
self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_not_found_html_view(self):
+ response = self.client.get('/not_found')
+ self.assertEquals(response.status_code, 404)
+ self.assertEquals(response.content, "404 Not Found")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_permission_denied_html_view(self):
+ response = self.client.get('/permission_denied')
+ self.assertEquals(response.status_code, 403)
+ self.assertEquals(response.content, "403 Forbidden")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+
+class TemplateHTMLRendererExceptionTests(TestCase):
+ urls = 'rest_framework.tests.htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == '404.html':
+ return Template("404: {{ detail }}")
+ if template_name == '403.html':
+ return Template("403: {{ detail }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_not_found_html_view_with_template(self):
+ response = self.client.get('/not_found')
+ self.assertEquals(response.status_code, 404)
+ self.assertEquals(response.content, "404: Not found")
+ self.assertEquals(response['Content-Type'], 'text/html')
+
+ def test_permission_denied_html_view_with_template(self):
+ response = self.client.get('/permission_denied')
+ self.assertEquals(response.status_code, 403)
+ self.assertEquals(response.content, "403: Permission denied")
+ self.assertEquals(response['Content-Type'], 'text/html')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 5532a8ee..f71e2e28 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -2,11 +2,28 @@ from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
-from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel
+from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo
factory = RequestFactory()
+class BlogPostCommentSerializer(serializers.ModelSerializer):
+ text = serializers.CharField()
+ blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail')
+
+ class Meta:
+ model = BlogPostComment
+ fields = ('text', 'blog_post_url')
+
+
+class PhotoSerializer(serializers.Serializer):
+ description = serializers.CharField()
+ album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title')
+
+ def restore_object(self, attrs, instance=None):
+ return Photo(**attrs)
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -32,12 +49,34 @@ class ManyToManyDetail(generics.RetrieveAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer
+class BlogPostCommentListCreate(generics.ListCreateAPIView):
+ model = BlogPostComment
+ serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostDetail(generics.RetrieveAPIView):
+ model = BlogPost
+
+
+class PhotoListCreate(generics.ListCreateAPIView):
+ model = Photo
+ model_serializer_class = PhotoSerializer
+
+
+class AlbumDetail(generics.RetrieveAPIView):
+ model = Album
+
+
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
+ url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
+ url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'),
+ url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'),
+ url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list')
)
@@ -124,3 +163,51 @@ class TestManyToManyHyperlinkedView(TestCase):
response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
+
+
+class TestCreateWithForeignKeys(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create a blog post
+ """
+ self.post = BlogPost.objects.create(title="Test post")
+ self.create_view = BlogPostCommentListCreate.as_view()
+
+ def test_create_comment(self):
+
+ data = {
+ 'text': 'A test comment',
+ 'blog_post_url': 'http://testserver/posts/1/'
+ }
+
+ request = factory.post('/comments/', data=data)
+ response = self.create_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(self.post.blogpostcomment_set.count(), 1)
+ self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
+
+
+class TestCreateWithForeignKeysAndCustomSlug(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create an Album
+ """
+ self.post = Album.objects.create(title='test-album')
+ self.list_create_view = PhotoListCreate.as_view()
+
+ def test_create_photo(self):
+
+ data = {
+ 'description': 'A test photo',
+ 'album_url': 'http://testserver/albums/test-album/'
+ }
+
+ request = factory.post('/photos/', data=data)
+ response = self.list_create_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(self.post.photo_set.count(), 1)
+ self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 6a758f0c..a2aba5be 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -40,7 +40,7 @@ class RESTFrameworkModel(models.Model):
Base for test models that sets app_label, so they play nicely.
"""
class Meta:
- app_label = 'rest_framework'
+ app_label = 'tests'
abstract = True
@@ -52,6 +52,11 @@ class BasicModel(RESTFrameworkModel):
text = models.CharField(max_length=100)
+class SlugBasedModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ slug = models.SlugField(max_length=32)
+
+
class DefaultValueModel(RESTFrameworkModel):
text = models.CharField(default='foobar', max_length=100)
@@ -63,6 +68,11 @@ class CallableDefaultValueModel(RESTFrameworkModel):
class ManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
+
+class ReadOnlyManyToManyModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100, default='anchor')
+ rel = models.ManyToManyField(Anchor)
+
# Models to test generic relations
@@ -85,9 +95,57 @@ class Bookmark(RESTFrameworkModel):
tags = GenericRelation(TaggedItem)
+# Model to test filtering.
+class FilterableItem(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
# Model for regression test for #285
class Comment(RESTFrameworkModel):
email = models.EmailField()
content = models.CharField(max_length=200)
created = models.DateTimeField(auto_now_add=True)
+
+
+class ActionItem(RESTFrameworkModel):
+ title = models.CharField(max_length=200)
+ done = models.BooleanField(default=False)
+
+
+# Models for reverse relations
+class BlogPost(RESTFrameworkModel):
+ title = models.CharField(max_length=100)
+
+
+class BlogPostComment(RESTFrameworkModel):
+ text = models.TextField()
+ blog_post = models.ForeignKey(BlogPost)
+
+
+class Album(RESTFrameworkModel):
+ title = models.CharField(max_length=100, unique=True)
+
+
+class Photo(RESTFrameworkModel):
+ description = models.TextField()
+ album = models.ForeignKey(Album)
+
+
+class Person(RESTFrameworkModel):
+ name = models.CharField(max_length=10)
+ age = models.IntegerField(null=True, blank=True)
+
+ @property
+ def info(self):
+ return {
+ 'name': self.name,
+ 'age': self.age,
+ }
+
+
+# Model for issue #324
+class BlankFieldModel(RESTFrameworkModel):
+ title = models.CharField(max_length=100, blank=True)
diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py
index d8265b43..e06354ea 100644
--- a/rest_framework/tests/negotiation.py
+++ b/rest_framework/tests/negotiation.py
@@ -18,20 +18,20 @@ class TestAcceptedMediaType(TestCase):
self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
self.negotiator = DefaultContentNegotiation()
- def negotiate(self, request):
- return self.negotiator.negotiate(request, self.renderers)
+ def select_renderer(self, request):
+ return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self):
request = factory.get('/')
- accepted_renderer, accepted_media_type = self.negotiate(request)
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json')
def test_client_underspecifies_accept_use_renderer(self):
request = factory.get('/', HTTP_ACCEPT='*/*')
- accepted_renderer, accepted_media_type = self.negotiate(request)
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json')
def test_client_overspecifies_accept_use_client(self):
request = factory.get('/', HTTP_ACCEPT='application/json; indent=8')
- accepted_renderer, accepted_media_type = self.negotiate(request)
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json; indent=8')
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index a939c9ef..713a7255 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -1,8 +1,12 @@
+import datetime
+from decimal import Decimal
from django.core.paginator import Paginator
from django.test import TestCase
from django.test.client import RequestFactory
-from rest_framework import generics, status, pagination
-from rest_framework.tests.models import BasicModel
+from django.utils import unittest
+from rest_framework import generics, status, pagination, filters
+from rest_framework.compat import django_filters
+from rest_framework.tests.models import BasicModel, FilterableItem
factory = RequestFactory()
@@ -15,6 +19,21 @@ class RootView(generics.ListCreateAPIView):
paginate_by = 10
+if django_filters:
+ class DecimalFilter(django_filters.FilterSet):
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ model = FilterableItem
+ paginate_by = 10
+ filter_class = DecimalFilter
+ filter_backend = filters.DjangoFilterBackend
+
+
class IntegrationTestPagination(TestCase):
"""
Integration tests for paginated list views.
@@ -22,7 +41,7 @@ class IntegrationTestPagination(TestCase):
def setUp(self):
"""
- Create 26 BasicModel intances.
+ Create 26 BasicModel instances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
@@ -62,6 +81,58 @@ class IntegrationTestPagination(TestCase):
self.assertNotEquals(response.data['previous'], None)
+class IntegrationTestPaginationAndFiltering(TestCase):
+
+ def setUp(self):
+ """
+ Create 50 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(26):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+ for obj in self.objects.all()
+ ]
+ self.view = FilterFieldsRootView.as_view()
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_paginated_filtered_root_view(self):
+ """
+ GET requests to paginated filtered ListCreateAPIView should return
+ paginated results. The next and previous links should preserve the
+ filtered parameters.
+ """
+ request = factory.get('/?decimal=15.20')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[10:15])
+ self.assertEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['previous'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 15)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+
class UnitTestPagination(TestCase):
"""
Unit tests for pagination of primative objects.
@@ -74,13 +145,13 @@ class UnitTestPagination(TestCase):
self.last_page = paginator.page(3)
def test_native_pagination(self):
- serializer = pagination.PaginationSerializer(instance=self.first_page)
+ serializer = pagination.PaginationSerializer(self.first_page)
self.assertEquals(serializer.data['count'], 26)
self.assertEquals(serializer.data['next'], '?page=2')
self.assertEquals(serializer.data['previous'], None)
self.assertEquals(serializer.data['results'], self.objects[:10])
- serializer = pagination.PaginationSerializer(instance=self.last_page)
+ serializer = pagination.PaginationSerializer(self.last_page)
self.assertEquals(serializer.data['count'], 26)
self.assertEquals(serializer.data['next'], None)
self.assertEquals(serializer.data['previous'], '?page=2')
diff --git a/rest_framework/tests/pk_relations.py b/rest_framework/tests/pk_relations.py
new file mode 100644
index 00000000..44ae4040
--- /dev/null
+++ b/rest_framework/tests/pk_relations.py
@@ -0,0 +1,205 @@
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+# ManyToMany
+
+class ManyToManyTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManySource(models.Model):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+
+
+class ManyToManyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.ManyPrimaryKeyRelatedField()
+
+ class Meta:
+ model = ManyToManyTarget
+
+
+class ManyToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManySource
+
+
+# ForeignKey
+
+class ForeignKeyTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class ForeignKeySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources')
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.ManyPrimaryKeyRelatedField(read_only=True)
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class PrimaryKeyManyToManyTests(TestCase):
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'targets': [1]},
+ {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]},
+ {'id': 2, 'name': u'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEquals(serializer.data, data)
+ self.assertEqual(obj.name, u'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': u'target-3', 'sources': [3]},
+ {'id': 4, 'name': u'target-4', 'sources': [1, 3]}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+class PrimaryKeyForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 1},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': u'target-2', 'sources': []},
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': u'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': 2},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1}
+ ]
+ self.assertEquals(serializer.data, expected)
+
+ # reverse foreign keys MUST be read_only
+ # In the general case they do not provide .remove() or .clear()
+ # and cannot be arbitrarily set.
+
+ # def test_reverse_foreign_key_update(self):
+ # data = {'id': 1, 'name': u'target-1', 'sources': [1]}
+ # instance = ForeignKeyTarget.objects.get(pk=1)
+ # serializer = ForeignKeyTargetSerializer(instance, data=data)
+ # self.assertTrue(serializer.is_valid())
+ # self.assertEquals(serializer.data, data)
+ # serializer.save()
+
+ # # Ensure target 1 is updated, and everything else is as expected
+ # queryset = ForeignKeyTarget.objects.all()
+ # serializer = ForeignKeyTargetSerializer(queryset)
+ # expected = [
+ # {'id': 1, 'name': u'target-1', 'sources': [1]},
+ # {'id': 2, 'name': u'target-2', 'sources': []},
+ # ]
+ # self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py
index 48d8d9bd..9be4b114 100644
--- a/rest_framework/tests/renderers.py
+++ b/rest_framework/tests/renderers.py
@@ -1,6 +1,8 @@
+import pickle
import re
from django.conf.urls.defaults import patterns, url, include
+from django.core.cache import cache
from django.test import TestCase
from django.test.client import RequestFactory
@@ -83,6 +85,7 @@ class HTMLView1(APIView):
urlpatterns = patterns('',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^cache$', MockGETView.as_view()),
url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
url(r'^html$', HTMLView.as_view()),
@@ -416,3 +419,89 @@ class XMLRendererTestCase(TestCase):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml))
+
+
+# Tests for caching issue, #346
+class CacheRenderTest(TestCase):
+ """
+ Tests specific to caching responses
+ """
+
+ urls = 'rest_framework.tests.renderers'
+
+ cache_key = 'just_a_cache_key'
+
+ @classmethod
+ def _get_pickling_errors(cls, obj, seen=None):
+ """ Return any errors that would be raised if `obj' is pickled
+ Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
+ """
+ if seen == None:
+ seen = []
+ try:
+ state = obj.__getstate__()
+ except AttributeError:
+ return
+ if state == None:
+ return
+ if isinstance(state,tuple):
+ if not isinstance(state[0],dict):
+ state=state[1]
+ else:
+ state=state[0].update(state[1])
+ result = {}
+ for i in state:
+ try:
+ pickle.dumps(state[i],protocol=2)
+ except pickle.PicklingError:
+ if not state[i] in seen:
+ seen.append(state[i])
+ result[i] = cls._get_pickling_errors(state[i],seen)
+ return result
+
+ def http_resp(self, http_method, url):
+ """
+ Simple wrapper for Client http requests
+ Removes the `client' and `request' attributes from as they are
+ added by django.test.client.Client and not part of caching
+ responses outside of tests.
+ """
+ method = getattr(self.client, http_method)
+ resp = method(url)
+ del resp.client, resp.request
+ return resp
+
+ def test_obj_pickling(self):
+ """
+ Test that responses are properly pickled
+ """
+ resp = self.http_resp('get', '/cache')
+
+ # Make sure that no pickling errors occurred
+ self.assertEqual(self._get_pickling_errors(resp), {})
+
+ # Unfortunately LocMem backend doesn't raise PickleErrors but returns
+ # None instead.
+ cache.set(self.cache_key, resp)
+ self.assertTrue(cache.get(self.cache_key) is not None)
+
+ def test_head_caching(self):
+ """
+ Test caching of HEAD requests
+ """
+ resp = self.http_resp('head', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+
+ def test_get_caching(self):
+ """
+ Test caching of GET requests
+ """
+ resp = self.http_resp('get', '/cache')
+ cache.set(self.cache_key, resp)
+
+ cached_resp = cache.get(self.cache_key)
+ self.assertIsInstance(cached_resp, Response)
+ self.assertEqual(cached_resp.content, resp.content)
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index f90bebf4..ff48f3fa 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -27,7 +27,7 @@ factory = RequestFactory()
class PlainTextParser(BaseParser):
media_type = 'text/plain'
- def parse_stream(self, stream, parser_context=None):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a 2-tuple of `(data, files)`.
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
index 18b6af39..d7b75450 100644
--- a/rest_framework/tests/response.py
+++ b/rest_framework/tests/response.py
@@ -131,12 +131,6 @@ class RendererIntegrationTests(TestCase):
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS)
- @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse')
- def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
- """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
- self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
-
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 256987ad..8d1de429 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -1,7 +1,14 @@
import datetime
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import *
+from rest_framework.tests.models import (ActionItem, Anchor, BasicModel,
+ BlankFieldModel, BlogPost, CallableDefaultValueModel, DefaultValueModel,
+ ManyToManyModel, Person, ReadOnlyManyToManyModel)
+
+
+class SubComment(object):
+ def __init__(self, sub_comment):
+ self.sub_comment = sub_comment
class Comment(object):
@@ -14,11 +21,16 @@ class Comment(object):
return all([getattr(self, attr) == getattr(other, attr)
for attr in ('email', 'content', 'created')])
+ def get_sub_comment(self):
+ sub_comment = SubComment('And Merry Christmas!')
+ return sub_comment
+
class CommentSerializer(serializers.Serializer):
email = serializers.EmailField()
content = serializers.CharField(max_length=1000)
created = serializers.DateTimeField()
+ sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
def restore_object(self, data, instance=None):
if instance is None:
@@ -28,6 +40,19 @@ class CommentSerializer(serializers.Serializer):
return instance
+class ActionItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ActionItem
+
+
+class PersonSerializer(serializers.ModelSerializer):
+ info = serializers.Field(source='info')
+
+ class Meta:
+ model = Person
+ fields = ('name', 'age', 'info')
+
+
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -38,36 +63,63 @@ class BasicTests(TestCase):
self.data = {
'email': 'tom@example.com',
'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1)
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'This wont change'
+ }
+ self.expected = {
+ 'email': 'tom@example.com',
+ 'content': 'Happy new year!',
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'And Merry Christmas!'
}
+ self.person_data = {'name': 'dwight', 'age': 35}
+ self.person = Person(**self.person_data)
+ self.person.save()
def test_empty(self):
serializer = CommentSerializer()
expected = {
'email': '',
'content': '',
- 'created': None
+ 'created': None,
+ 'sub_comment': ''
}
self.assertEquals(serializer.data, expected)
def test_retrieve(self):
- serializer = CommentSerializer(instance=self.comment)
- expected = self.data
- self.assertEquals(serializer.data, expected)
+ serializer = CommentSerializer(self.comment)
+ self.assertEquals(serializer.data, self.expected)
def test_create(self):
- serializer = CommentSerializer(self.data)
+ serializer = CommentSerializer(data=self.data)
expected = self.comment
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected)
self.assertFalse(serializer.object is expected)
+ self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
def test_update(self):
- serializer = CommentSerializer(self.data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=self.data)
expected = self.comment
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected)
self.assertTrue(serializer.object is expected)
+ self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+
+ def test_model_fields_as_expected(self):
+ """ Make sure that the fields returned are the same as defined
+ in the Meta data
+ """
+ serializer = PersonSerializer(self.person)
+ self.assertEquals(set(serializer.data.keys()),
+ set(['name', 'age', 'info']))
+
+ def test_field_with_dictionary(self):
+ """ Make sure that dictionaries from fields are left intact
+ """
+ serializer = PersonSerializer(self.person)
+ expected = self.person_data
+ self.assertEquals(serializer.data['info'], expected)
class ValidationTests(TestCase):
@@ -82,14 +134,16 @@ class ValidationTests(TestCase):
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
}
+ self.actionitem = ActionItem('Some to do item',
+ )
def test_create(self):
- serializer = CommentSerializer(self.data)
+ serializer = CommentSerializer(data=self.data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
def test_update(self):
- serializer = CommentSerializer(self.data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=self.data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
@@ -98,10 +152,78 @@ class ValidationTests(TestCase):
'content': 'xxx',
'created': datetime.datetime(2012, 1, 1)
}
- serializer = CommentSerializer(data, instance=self.comment)
+ serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'email': [u'This field is required.']})
+ def test_missing_bool_with_default(self):
+ """Make sure that a boolean value with a 'False' value is not
+ mistaken for not having a default."""
+ data = {
+ 'title': 'Some action item',
+ #No 'done' value.
+ }
+ serializer = ActionItemSerializer(self.actionitem, data=data)
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.errors, {})
+
+ def test_field_validation(self):
+
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = CommentSerializerWithFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = CommentSerializerWithFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+
+ def test_cross_field_validation(self):
+
+ class CommentSerializerWithCrossFieldValidator(CommentSerializer):
+
+ def validate(self, attrs):
+ if attrs["email"] not in attrs["content"]:
+ raise serializers.ValidationError("Email address not in content")
+ return attrs
+
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A comment from tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'A comment from foo@bar.com'
+
+ serializer = CommentSerializerWithCrossFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']})
+
+ def test_null_is_true_fields(self):
+ """
+ Omitting a value for null-field should validate.
+ """
+ serializer = PersonSerializer(data={'name': 'marko'})
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.errors, {})
+
class MetadataTests(TestCase):
def test_empty(self):
@@ -148,7 +270,7 @@ class ManyToManyTests(TestCase):
Create an instance of a model with a ManyToMany relationship.
"""
data = {'rel': [self.anchor.id]}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
@@ -162,7 +284,7 @@ class ManyToManyTests(TestCase):
new_anchor = Anchor()
new_anchor.save()
data = {'rel': [self.anchor.id, new_anchor.id]}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 1)
@@ -175,7 +297,7 @@ class ManyToManyTests(TestCase):
containing no items.
"""
data = {'rel': []}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
@@ -190,7 +312,7 @@ class ManyToManyTests(TestCase):
new_anchor = Anchor()
new_anchor.save()
data = {'rel': []}
- serializer = self.serializer_class(data, instance=self.instance)
+ serializer = self.serializer_class(self.instance, data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 1)
@@ -204,7 +326,7 @@ class ManyToManyTests(TestCase):
lists (eg form data).
"""
data = {'rel': ''}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
@@ -212,6 +334,61 @@ class ManyToManyTests(TestCase):
self.assertEquals(list(instance.rel.all()), [])
+class ReadOnlyManyToManyTests(TestCase):
+ def setUp(self):
+ class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
+ rel = serializers.ManyRelatedField(read_only=True)
+
+ class Meta:
+ model = ReadOnlyManyToManyModel
+
+ self.serializer_class = ReadOnlyManyToManySerializer
+
+ # An anchor instance to use for the relationship
+ self.anchor = Anchor()
+ self.anchor.save()
+
+ # A model instance with a many to many relationship to the anchor
+ self.instance = ReadOnlyManyToManyModel()
+ self.instance.save()
+ self.instance.rel.add(self.anchor)
+
+ # A serialized representation of the model instance
+ self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
+
+ def test_update(self):
+ """
+ Attempt to update an instance of a model with a ManyToMany
+ relationship. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': [self.anchor.id, new_anchor.id]}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEquals(list(instance.rel.all()), [self.anchor])
+
+ def test_update_without_relationship(self):
+ """
+ Attempt to update an instance of a model where many to ManyToMany
+ relationship is not supplied. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {}
+ serializer = self.serializer_class(self.instance, data=data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEquals(list(instance.rel.all()), [self.anchor])
+
+
class DefaultValueTests(TestCase):
def setUp(self):
class DefaultValueSerializer(serializers.ModelSerializer):
@@ -223,7 +400,7 @@ class DefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -232,7 +409,7 @@ class DefaultValueTests(TestCase):
def test_create_overriding_default(self):
data = {'text': 'overridden'}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -251,7 +428,7 @@ class CallableDefaultValueTests(TestCase):
def test_create_using_default(self):
data = {}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
@@ -260,9 +437,87 @@ class CallableDefaultValueTests(TestCase):
def test_create_overriding_default(self):
data = {'text': 'overridden'}
- serializer = self.serializer_class(data)
+ serializer = self.serializer_class(data=data)
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
self.assertEquals(len(self.objects.all()), 1)
self.assertEquals(instance.pk, 1)
self.assertEquals(instance.text, 'overridden')
+
+
+class ManyRelatedTests(TestCase):
+ def setUp(self):
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ self.serializer_class = BlogPostSerializer
+
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ serializer = self.serializer_class(instance=post)
+ expected = {
+ 'title': 'Test blog post',
+ 'comments': [
+ {'text': 'I hate this blog post'},
+ {'text': 'I love this blog post'}
+ ]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+# Test for issue #324
+class BlankFieldTests(TestCase):
+ def setUp(self):
+
+ class BlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlankFieldModel
+
+ class BlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField(blank=True)
+
+ class NotBlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+ class NotBlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField()
+
+ self.model_serializer_class = BlankFieldModelSerializer
+ self.serializer_class = BlankFieldSerializer
+ self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
+ self.not_blank_serializer_class = NotBlankFieldSerializer
+ self.data = {'title': ''}
+
+ def test_create_blank_field(self):
+ serializer = self.serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_model_blank_field(self):
+ serializer = self.model_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a non-model serializer
+ """
+ serializer = self.not_blank_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), False)
+
+ def test_create_model_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a model serializer
+ """
+ serializer = self.not_blank_model_serializer_class(data=self.data)
+ self.assertEquals(serializer.is_valid(), False)
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
new file mode 100644
index 00000000..adeaf6da
--- /dev/null
+++ b/rest_framework/tests/tests.py
@@ -0,0 +1,13 @@
+"""
+Force import of all modules in this package in order to get the standard test
+runner to pick up the tests. Yowzers.
+"""
+import os
+
+modules = [filename.rsplit('.', 1)[0]
+ for filename in os.listdir(os.path.dirname(__file__))
+ if filename.endswith('.py') and not filename.startswith('_')]
+__test__ = dict()
+
+for module in modules:
+ exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py
index b390c42f..c032985e 100644
--- a/rest_framework/tests/validators.py
+++ b/rest_framework/tests/validators.py
@@ -285,7 +285,7 @@
# uiop = models.CharField(max_length=256, blank=True)
# @property
-# def readonly(self):
+# def read_only(self):
# return 'read only'
# class MockResource(ModelResource):
@@ -298,7 +298,7 @@
# def test_property_fields_are_allowed_on_model_forms(self):
# """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
-# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'}
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
# self.assertEqual(self.validator.validate_request(content, None), content)
# def test_property_fields_are_not_required_on_model_forms(self):
@@ -310,19 +310,19 @@
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'}
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
# def test_validate_requires_fields_on_model_forms(self):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'readonly': 'read only'}
+# content = {'read_only': 'read only'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
# def test_validate_does_not_require_blankable_fields_on_model_forms(self):
# """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
-# content = {'qwerty': 'example', 'readonly': 'read only'}
+# content = {'qwerty': 'example', 'read_only': 'read only'}
# self.validator.validate_request(content, None)
# def test_model_form_validator_uses_model_forms(self):
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 6e7a0b72..8fe64248 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -16,7 +16,7 @@ class BaseThrottle(object):
def wait(self):
"""
- Optionally, return a recommeded number of seconds to wait before
+ Optionally, return a recommended number of seconds to wait before
the next request.
"""
return None
@@ -60,7 +60,7 @@ class SimpleRateThrottle(BaseThrottle):
Determine the string representation of the allowed request rate.
"""
if not getattr(self, 'scope', None):
- msg = ("You must set either `.scope` or `.rate` for '%s' thottle" %
+ msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise exceptions.ConfigurationError(msg)
@@ -137,7 +137,7 @@ class AnonRateThrottle(SimpleRateThrottle):
"""
Limits the rate of API calls that may be made by a anonymous users.
- The IP address of the request will be used as the unqiue cache key.
+ The IP address of the request will be used as the unique cache key.
"""
scope = 'anon'
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 386c78a2..316ccd19 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -2,26 +2,23 @@ from django.conf.urls.defaults import url
from rest_framework.settings import api_settings
-def format_suffix_patterns(urlpatterns, suffix_required=False,
- suffix_kwarg=None, allowed=None):
+def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
"""
Supplement existing urlpatterns with corrosponding patterns that also
include a '.format' suffix. Retains urlpattern ordering.
+ urlpatterns:
+ A list of URL patterns.
+
suffix_required:
If `True`, only suffixed URLs will be generated, and non-suffixed
URLs will not be used. Defaults to `False`.
- suffix_kwarg:
- The name of the kwarg that will be passed to the view.
- Defaults to 'format'.
-
allowed:
An optional tuple/list of allowed suffixes. eg ['json', 'api']
Defaults to `None`, which allows any suffix.
-
"""
- suffix_kwarg = suffix_kwarg or api_settings.FORMAT_SUFFIX_KWARG
+ suffix_kwarg = api_settings.FORMAT_SUFFIX_KWARG
if allowed:
if len(allowed) == 1:
allowed_pattern = allowed[0]
diff --git a/rest_framework/utils/__init__.py b/rest_framework/utils/__init__.py
index a59fff45..84fcb5db 100644
--- a/rest_framework/utils/__init__.py
+++ b/rest_framework/utils/__init__.py
@@ -1,7 +1,6 @@
from django.utils.encoding import smart_unicode
from django.utils.xmlutils import SimplerXMLGenerator
from rest_framework.compat import StringIO
-
import re
import xml.etree.ElementTree as ET
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index 672d32a3..80e39d46 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -6,7 +6,7 @@ def get_breadcrumbs(url):
from rest_framework.views import APIView
- def breadcrumbs_recursive(url, breadcrumbs_list, prefix):
+ def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
"""Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url."""
try:
@@ -16,7 +16,11 @@ def get_breadcrumbs(url):
else:
# Check if this is a REST framework view, and if so add it to the breadcrumbs
if isinstance(getattr(view, 'cls_instance', None), APIView):
- breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url))
+ # Don't list the same view twice in a row.
+ # Probably an optional trailing slash.
+ if not seen or seen[-1] != view:
+ breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url))
+ seen.append(view)
if url == '':
# All done
@@ -24,11 +28,11 @@ def get_breadcrumbs(url):
elif url.endswith('/'):
# Drop trailing slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix)
+ return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen)
# Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix)
+ return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/')
url = url[len(prefix):]
- return breadcrumbs_recursive(url, [], prefix)
+ return breadcrumbs_recursive(url, [], prefix, [])
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 62fc92f9..1afbd697 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -54,12 +54,12 @@ def _camelcase_to_spaces(content):
class APIView(View):
settings = api_settings
- renderer_classes = api_settings.DEFAULT_RENDERERS
- parser_classes = api_settings.DEFAULT_PARSERS
- authentication_classes = api_settings.DEFAULT_AUTHENTICATION
- throttle_classes = api_settings.DEFAULT_THROTTLES
- permission_classes = api_settings.DEFAULT_PERMISSIONS
- content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION
+ renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
+ parser_classes = api_settings.DEFAULT_PARSER_CLASSES
+ authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
+ throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
+ permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
+ content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
@classmethod
def as_view(cls, **initkwargs):
@@ -158,12 +158,15 @@ class APIView(View):
def get_parser_context(self, http_request):
"""
- Returns a dict that is passed through to Parser.parse_stream(),
+ Returns a dict that is passed through to Parser.parse(),
as the `parser_context` keyword argument.
"""
+ # Note: Additionally `request` will also be added to the context
+ # by the Request object.
return {
- 'upload_handlers': http_request.upload_handlers,
- 'meta': http_request.META,
+ 'view': self,
+ 'args': getattr(self, 'args', ()),
+ 'kwargs': getattr(self, 'kwargs', {})
}
def get_renderer_context(self):
@@ -171,13 +174,13 @@ class APIView(View):
Returns a dict that is passed through to Renderer.render(),
as the `renderer_context` keyword argument.
"""
- # Note: Additionally 'response' will also be set on the context,
+ # Note: Additionally 'response' will also be added to the context,
# by the Response object.
return {
'view': self,
- 'request': self.request,
- 'args': self.args,
- 'kwargs': self.kwargs
+ 'args': getattr(self, 'args', ()),
+ 'kwargs': getattr(self, 'kwargs', {}),
+ 'request': getattr(self, 'request', None)
}
# API policy instantiation methods
@@ -215,7 +218,7 @@ class APIView(View):
def get_throttles(self):
"""
- Instantiates and returns the list of thottles that this view uses.
+ Instantiates and returns the list of throttles that this view uses.
"""
return [throttle() for throttle in self.throttle_classes]
@@ -235,7 +238,13 @@ class APIView(View):
"""
renderers = self.get_renderers()
conneg = self.get_content_negotiator()
- return conneg.negotiate(request, renderers, self.format_kwarg, force)
+
+ try:
+ return conneg.select_renderer(request, renderers, self.format_kwarg)
+ except:
+ if force:
+ return (renderers[0], renderers[0].media_type)
+ raise
def has_permission(self, request, obj=None):
"""
@@ -311,13 +320,17 @@ class APIView(View):
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
if isinstance(exc, exceptions.APIException):
- return Response({'detail': exc.detail}, status=exc.status_code)
+ return Response({'detail': exc.detail},
+ status=exc.status_code,
+ exception=True)
elif isinstance(exc, Http404):
return Response({'detail': 'Not found'},
- status=status.HTTP_404_NOT_FOUND)
+ status=status.HTTP_404_NOT_FOUND,
+ exception=True)
elif isinstance(exc, PermissionDenied):
return Response({'detail': 'Permission denied'},
- status=status.HTTP_403_FORBIDDEN)
+ status=status.HTTP_403_FORBIDDEN,
+ exception=True)
raise
# Note: session based authentication is explicitly CSRF validated,