aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorBen Konrath2012-11-01 14:06:56 +0100
committerBen Konrath2012-11-01 14:06:56 +0100
commit9c82f9717e58f1bb250d5fd4b27619dbcbbd1f21 (patch)
treee976854e6871a8b826e91d8eb16d9a139b90664f /rest_framework
parentc24997df3b943e5d7a3b2e101508e4b79ee82dc4 (diff)
parent204db7bdaa59cd17f762d6cf0e6a8623c2cc9939 (diff)
downloaddjango-rest-framework-9c82f9717e58f1bb250d5fd4b27619dbcbbd1f21.tar.bz2
Merge branch 'master' into restframework2-filter
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py62
-rw-r--r--rest_framework/compat.py4
-rw-r--r--rest_framework/decorators.py14
-rw-r--r--rest_framework/exceptions.py8
-rw-r--r--rest_framework/fields.py179
-rw-r--r--rest_framework/generics.py69
-rw-r--r--rest_framework/mixins.py25
-rw-r--r--rest_framework/negotiation.py36
-rw-r--r--rest_framework/parsers.py68
-rw-r--r--rest_framework/permissions.py28
-rw-r--r--rest_framework/renderers.py148
-rw-r--r--rest_framework/request.py38
-rw-r--r--rest_framework/resources.py95
-rw-r--r--rest_framework/reverse.py8
-rwxr-xr-xrest_framework/runtests/runtests.py2
-rw-r--r--rest_framework/runtests/settings.py9
-rw-r--r--rest_framework/serializers.py128
-rw-r--r--rest_framework/settings.py48
-rw-r--r--rest_framework/static/rest_framework/css/default.css14
-rw-r--r--rest_framework/templates/rest_framework/base.html2
-rw-r--r--rest_framework/templates/rest_framework/login.html56
-rw-r--r--rest_framework/tests/__init__.py13
-rw-r--r--rest_framework/tests/generics.py51
-rw-r--r--rest_framework/tests/htmlrenderer.py6
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py44
-rw-r--r--rest_framework/tests/models.py37
-rw-r--r--rest_framework/tests/negotiation.py10
-rw-r--r--rest_framework/tests/request.py15
-rw-r--r--rest_framework/tests/serializer.py240
-rw-r--r--rest_framework/tests/tests.py13
-rw-r--r--rest_framework/tests/validators.py10
-rw-r--r--rest_framework/throttling.py14
-rw-r--r--rest_framework/urlpatterns.py13
-rw-r--r--rest_framework/utils/mediatypes.py26
-rw-r--r--rest_framework/views.py55
35 files changed, 1045 insertions, 543 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index ee5bd2f2..30c78ebc 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -1,10 +1,10 @@
"""
-The :mod:`authentication` module provides a set of pluggable authentication classes.
-
-Authentication behavior is provided by mixing the :class:`mixins.RequestMixin` class into a :class:`View` class.
+Provides a set of pluggable authentication policies.
"""
from django.contrib.auth import authenticate
+from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError
+from rest_framework import exceptions
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.authtoken.models import Token
import base64
@@ -17,25 +17,14 @@ class BaseAuthentication(object):
def authenticate(self, request):
"""
- Authenticate the :obj:`request` and return a :obj:`User` or :const:`None`. [*]_
-
- .. [*] The authentication context *will* typically be a :obj:`User`,
- but it need not be. It can be any user-like object so long as the
- permissions classes (see the :mod:`permissions` module) on the view can
- handle the object and use it to determine if the request has the required
- permissions or not.
-
- This can be an important distinction if you're implementing some token
- based authentication mechanism, where the authentication context
- may be more involved than simply mapping to a :obj:`User`.
+ Authenticate the request and return a two-tuple of (user, token).
"""
- return None
+ raise NotImplementedError(".authenticate() must be overridden.")
class BasicAuthentication(BaseAuthentication):
"""
- Base class for HTTP Basic authentication.
- Subclasses should implement `.authenticate_credentials()`.
+ HTTP Basic authentication against username/password.
"""
def authenticate(self, request):
@@ -43,8 +32,6 @@ class BasicAuthentication(BaseAuthentication):
Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`.
"""
- from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError
-
if 'HTTP_AUTHORIZATION' in request.META:
auth = request.META['HTTP_AUTHORIZATION'].split()
if len(auth) == 2 and auth[0].lower() == "basic":
@@ -54,7 +41,8 @@ class BasicAuthentication(BaseAuthentication):
return None
try:
- userid, password = smart_unicode(auth_parts[0]), smart_unicode(auth_parts[2])
+ userid = smart_unicode(auth_parts[0])
+ password = smart_unicode(auth_parts[2])
except DjangoUnicodeDecodeError:
return None
@@ -62,15 +50,6 @@ class BasicAuthentication(BaseAuthentication):
def authenticate_credentials(self, userid, password):
"""
- Given the Basic authentication userid and password, authenticate
- and return a user instance.
- """
- raise NotImplementedError('.authenticate_credentials() must be overridden')
-
-
-class UserBasicAuthentication(BasicAuthentication):
- def authenticate_credentials(self, userid, password):
- """
Authenticate the userid and password against username and password.
"""
user = authenticate(username=userid, password=password)
@@ -85,20 +64,31 @@ class SessionAuthentication(BaseAuthentication):
def authenticate(self, request):
"""
- Returns a :obj:`User` if the request session currently has a logged in user.
- Otherwise returns :const:`None`.
+ Returns a `User` if the request session currently has a logged in user.
+ Otherwise returns `None`.
"""
# Get the underlying HttpRequest object
http_request = request._request
user = getattr(http_request, 'user', None)
- if user and user.is_active:
- # Enforce CSRF validation for session based authentication.
- resp = CsrfViewMiddleware().process_view(http_request, None, (), {})
+ # Unauthenticated, CSRF validation not required
+ if not user or not user.is_active:
+ return
+
+ # Enforce CSRF validation for session based authentication.
+ class CSRFCheck(CsrfViewMiddleware):
+ def _reject(self, request, reason):
+ # Return the failure reason instead of an HttpResponse
+ return reason
+
+ reason = CSRFCheck().process_view(http_request, None, (), {})
+ if reason:
+ # CSRF failed, bail with explicit error message
+ raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
- if resp is None: # csrf passed
- return (user, None)
+ # CSRF passed with authenticated user
+ return (user, None)
class TokenAuthentication(BaseAuthentication):
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 7664c400..b0367a32 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -1,6 +1,8 @@
"""
-The :mod:`compat` module provides support for backwards compatibility with older versions of django/python.
+The `compat` module provides support for backwards compatibility with older
+versions of django/python, and compatbility wrappers around optional packages.
"""
+# flake8: noqa
import django
# cStringIO only if it's available, otherwise StringIO
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 948973ae..a231f191 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -10,8 +10,18 @@ def api_view(http_method_names):
def decorator(func):
- class WrappedAPIView(APIView):
- pass
+ WrappedAPIView = type(
+ 'WrappedAPIView',
+ (APIView,),
+ {'__doc__': func.__doc__}
+ )
+
+ # Note, the above allows us to set the docstring.
+ # It is the equivelent of:
+ #
+ # class WrappedAPIView(APIView):
+ # pass
+ # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index 572425b9..89479deb 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -31,14 +31,6 @@ class PermissionDenied(APIException):
self.detail = detail or self.default_detail
-class InvalidFormat(APIException):
- status_code = status.HTTP_404_NOT_FOUND
- default_detail = "Format suffix '.%s' not found."
-
- def __init__(self, format, detail=None):
- self.detail = (detail or self.default_detail) % format
-
-
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = "Method '%s' not allowed."
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index bb9a523d..73c8f72b 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -5,13 +5,15 @@ import warnings
from django.core import validators
from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve
+from django.core.urlresolvers import resolve, get_script_prefix
from django.conf import settings
+from django.forms import widgets
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
from rest_framework.reverse import reverse
from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone
+from urlparse import urlparse
def is_simple_callable(obj):
@@ -42,7 +44,7 @@ class Field(object):
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
- model_field - The model field this field corrosponds to, if one exists.
+ model_field - The model field this field corresponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
@@ -70,6 +72,8 @@ class Field(object):
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
+ if is_simple_callable(value):
+ value = value()
else:
value = getattr(obj, field_name)
return self.to_native(value)
@@ -105,15 +109,20 @@ class WritableField(Field):
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
+ widget = widgets.TextInput
+ default = None
+
+ def __init__(self, source=None, read_only=False, required=None,
+ validators=[], error_messages=None, widget=None,
+ default=None, blank=None):
- def __init__(self, source=None, readonly=False, required=None,
- validators=[], error_messages=None):
super(WritableField, self).__init__(source=source)
- self.readonly = readonly
+
+ self.read_only = read_only
if required is None:
- self.required = not(readonly)
+ self.required = not(read_only)
else:
- assert not readonly, "Cannot set required=True and readonly=True"
+ assert not read_only, "Cannot set required=True and read_only=True"
self.required = required
messages = {}
@@ -123,6 +132,14 @@ class WritableField(Field):
self.error_messages = messages
self.validators = self.default_validators + validators
+ self.default = default or self.default
+ self.blank = blank
+
+ # Widgets are ony used for HTML forms.
+ widget = widget or self.widget
+ if isinstance(widget, type):
+ widget = widget()
+ self.widget = widget
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
@@ -151,15 +168,18 @@ class WritableField(Field):
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
"""
- if self.readonly:
+ if self.read_only:
return
try:
native = data[field_name]
except KeyError:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
+ if self.default is not None:
+ native = self.default
+ else:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
+ return
value = self.from_native(native)
if self.source == '*':
@@ -179,7 +199,7 @@ class WritableField(Field):
class ModelField(WritableField):
"""
- A generic field that can be used against an arbirtrary model field.
+ A generic field that can be used against an arbitrary model field.
"""
def __init__(self, *args, **kwargs):
try:
@@ -191,9 +211,9 @@ class ModelField(WritableField):
def from_native(self, value):
try:
rel = self.model_field.rel
+ return rel.to._meta.get_field(rel.field_name).to_python(value)
except:
return self.model_field.to_python(value)
- return rel.to._meta.get_field(rel.field_name).to_python(value)
def field_to_native(self, obj, field_name):
value = self.model_field._get_val_from_obj(obj)
@@ -222,8 +242,11 @@ class RelatedField(WritableField):
return self.to_native(value)
def field_from_native(self, data, field_name, into):
+ if self.read_only:
+ return
+
value = data.get(field_name)
- into[(self.source or field_name) + '_id'] = self.from_native(value)
+ into[(self.source or field_name)] = self.from_native(value)
class ManyRelatedMixin(object):
@@ -235,6 +258,9 @@ class ManyRelatedMixin(object):
return [self.to_native(item) for item in value.all()]
def field_from_native(self, data, field_name, into):
+ if self.read_only:
+ return
+
try:
# Form data
value = data.getlist(self.source or field_name)
@@ -264,6 +290,15 @@ class PrimaryKeyRelatedField(RelatedField):
def to_native(self, pk):
return pk
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ raise ValidationError('Invalid hyperlink - object does not exist.')
+
def field_to_native(self, obj, field_name):
try:
# Prefer obj.serializable_value for performance reasons
@@ -307,14 +342,16 @@ class HyperlinkedRelatedField(RelatedField):
self.view_name = kwargs.pop('view_name')
except:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+ self.format = kwargs.pop('format', None)
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
def to_native(self, obj):
view_name = self.view_name
request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
kwargs = {self.pk_url_kwarg: obj.pk}
try:
- return reverse(view_name, kwargs=kwargs, request=request)
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
except:
pass
@@ -325,13 +362,13 @@ class HyperlinkedRelatedField(RelatedField):
kwargs = {self.slug_url_kwarg: slug}
try:
- return reverse(self.view_name, kwargs=kwargs, request=request)
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
except:
pass
kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
try:
- return reverse(self.view_name, kwargs=kwargs, request=request)
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
except:
pass
@@ -340,6 +377,16 @@ class HyperlinkedRelatedField(RelatedField):
def from_native(self, value):
# Convert URL -> model instance pk
# TODO: Use values_list
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ if value.startswith('http:') or value.startswith('https:'):
+ # If needed convert absolute URLs to relative path
+ value = urlparse(value).path
+ prefix = get_script_prefix()
+ if value.startswith(prefix):
+ value = '/' + value[len(prefix):]
+
try:
match = resolve(value)
except:
@@ -353,7 +400,7 @@ class HyperlinkedRelatedField(RelatedField):
# Try explicit primary key.
if pk is not None:
- return pk
+ queryset = self.queryset.filter(pk=pk)
# Next, try looking up by slug.
elif slug is not None:
slug_field = self.get_slug_field()
@@ -366,7 +413,7 @@ class HyperlinkedRelatedField(RelatedField):
obj = queryset.get()
except ObjectDoesNotExist:
raise ValidationError('Invalid hyperlink - object does not exist.')
- return obj.pk
+ return obj
class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
@@ -381,33 +428,38 @@ class HyperlinkedIdentityField(Field):
# TODO: Make this mandatory, and have the HyperlinkedModelSerializer
# set it on-the-fly
self.view_name = kwargs.pop('view_name', None)
+ self.format = kwargs.pop('format', None)
super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
view_name = self.view_name or self.parent.opts.view_name
view_kwargs = {'pk': obj.pk}
- return reverse(view_name, kwargs=view_kwargs, request=request)
+ return reverse(view_name, kwargs=view_kwargs, request=request, format=format)
##### Typed Fields #####
class BooleanField(WritableField):
type_name = 'BooleanField'
+ widget = widgets.CheckboxInput
default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."),
}
+ empty = False
+
+ # Note: we set default to `False` in order to fill in missing value not
+ # supplied by html form. TODO: Fix so that only html form input gets
+ # this behavior.
+ default = False
def from_native(self, value):
- if value in (True, False):
- # if value is 1 or 0 than it's equal to True or False, but we want
- # to return a true bool for semantic reasons.
- return bool(value)
if value in ('t', 'True', '1'):
return True
if value in ('f', 'False', '0'):
return False
- raise ValidationError(self.error_messages['invalid'] % value)
+ return bool(value)
class CharField(WritableField):
@@ -421,12 +473,68 @@ class CharField(WritableField):
if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length))
+ def validate(self, value):
+ """
+ Validates that the value is supplied (if required).
+ """
+ # if empty string and allow blank
+ if self.blank and not value:
+ return
+ else:
+ super(CharField, self).validate(value)
+
def from_native(self, value):
if isinstance(value, basestring) or value is None:
return value
return smart_unicode(value)
+class ChoiceField(WritableField):
+ type_name = 'ChoiceField'
+ widget = widgets.Select
+ default_error_messages = {
+ 'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'),
+ }
+
+ def __init__(self, choices=(), *args, **kwargs):
+ super(ChoiceField, self).__init__(*args, **kwargs)
+ self.choices = choices
+
+ def _get_choices(self):
+ return self._choices
+
+ def _set_choices(self, value):
+ # Setting choices also sets the choices on the widget.
+ # choices can be any iterable, but we call list() on it because
+ # it will be consumed more than once.
+ self._choices = self.widget.choices = list(value)
+
+ choices = property(_get_choices, _set_choices)
+
+ def validate(self, value):
+ """
+ Validates that the input is in self.choices.
+ """
+ super(ChoiceField, self).validate(value)
+ if value and not self.valid_value(value):
+ raise ValidationError(self.error_messages['invalid_choice'] % {'value': value})
+
+ def valid_value(self, value):
+ """
+ Check to see if the provided value is a valid choice.
+ """
+ for k, v in self.choices:
+ if isinstance(v, (list, tuple)):
+ # This is an optgroup, so look inside the group for options
+ for k2, v2 in v:
+ if value == smart_unicode(k2):
+ return True
+ else:
+ if value == smart_unicode(k):
+ return True
+ return False
+
+
class EmailField(CharField):
type_name = 'EmailField'
@@ -436,7 +544,10 @@ class EmailField(CharField):
default_validators = [validators.validate_email]
def from_native(self, value):
- return super(EmailField, self).from_native(value).strip()
+ ret = super(EmailField, self).from_native(value)
+ if ret is None:
+ return None
+ return ret.strip()
def __deepcopy__(self, memo):
result = copy.copy(self)
@@ -458,8 +569,9 @@ class DateField(WritableField):
empty = None
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
if isinstance(value, datetime.datetime):
if timezone and settings.USE_TZ and timezone.is_aware(value):
# Convert aware datetimes to the default time zone
@@ -497,8 +609,9 @@ class DateTimeField(WritableField):
empty = None
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
@@ -556,6 +669,7 @@ class IntegerField(WritableField):
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
+
try:
value = int(str(value))
except (ValueError, TypeError):
@@ -571,8 +685,9 @@ class FloatField(WritableField):
}
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
try:
return float(value)
except (TypeError, ValueError):
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 3b2bea3b..063382bb 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -10,12 +10,12 @@ import django_filters
### Base classes for the generic views ###
-class BaseView(views.APIView):
+class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
serializer_class = None
- model_serializer_class = api_settings.MODEL_SERIALIZER
+ model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
def get_serializer_context(self):
"""
@@ -51,12 +51,12 @@ class BaseView(views.APIView):
return serializer_class(data, instance=instance, context=context)
-class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
+class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
"""
Base class for generic views onto a queryset.
"""
- pagination_serializer_class = api_settings.PAGINATION_SERIALIZER
+ pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
paginate_by = api_settings.PAGINATE_BY
filter_class = None
filter_fields = None
@@ -106,7 +106,7 @@ class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
return pagination_serializer_class(instance=page, context=context)
-class SingleObjectBaseView(SingleObjectMixin, BaseView):
+class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
"""
Base class for generic views onto a model instance.
"""
@@ -117,7 +117,7 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView):
"""
Override default to add support for object-level permissions.
"""
- obj = super(SingleObjectBaseView, self).get_object()
+ obj = super(SingleObjectAPIView, self).get_object()
if not self.has_permission(self.request, obj):
self.permission_denied(self.request)
return obj
@@ -126,8 +126,19 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView):
### Concrete view classes that provide method handlers ###
### by composing the mixin classes with a base view. ###
+
+class CreateAPIView(mixins.CreateModelMixin,
+ GenericAPIView):
+
+ """
+ Concrete view for creating a model instance.
+ """
+ def post(self, request, *args, **kwargs):
+ return self.create(request, *args, **kwargs)
+
+
class ListAPIView(mixins.ListModelMixin,
- MultipleObjectBaseView):
+ MultipleObjectAPIView):
"""
Concrete view for listing a queryset.
"""
@@ -135,9 +146,38 @@ class ListAPIView(mixins.ListModelMixin,
return self.list(request, *args, **kwargs)
+class RetrieveAPIView(mixins.RetrieveModelMixin,
+ SingleObjectAPIView):
+ """
+ Concrete view for retrieving a model instance.
+ """
+ def get(self, request, *args, **kwargs):
+ return self.retrieve(request, *args, **kwargs)
+
+
+class DestroyAPIView(mixins.DestroyModelMixin,
+ SingleObjectAPIView):
+
+ """
+ Concrete view for deleting a model instance.
+ """
+ def delete(self, request, *args, **kwargs):
+ return self.destroy(request, *args, **kwargs)
+
+
+class UpdateAPIView(mixins.UpdateModelMixin,
+ SingleObjectAPIView):
+
+ """
+ Concrete view for updating a model instance.
+ """
+ def put(self, request, *args, **kwargs):
+ return self.update(request, *args, **kwargs)
+
+
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
- MultipleObjectBaseView):
+ MultipleObjectAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
@@ -148,18 +188,9 @@ class ListCreateAPIView(mixins.ListModelMixin,
return self.create(request, *args, **kwargs)
-class RetrieveAPIView(mixins.RetrieveModelMixin,
- SingleObjectBaseView):
- """
- Concrete view for retrieving a model instance.
- """
- def get(self, request, *args, **kwargs):
- return self.retrieve(request, *args, **kwargs)
-
-
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
- SingleObjectBaseView):
+ SingleObjectAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
@@ -173,7 +204,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
- SingleObjectBaseView):
+ SingleObjectAPIView):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 04626fb0..b0cc043a 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -3,9 +3,6 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
-
-Eg. Use mixins to build a Resource class, and have a Router class
- perform the binding of http methods to actions for us.
"""
from django.http import Http404
from rest_framework import status
@@ -20,10 +17,14 @@ class CreateModelMixin(object):
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA)
if serializer.is_valid():
+ self.pre_save(serializer.object)
self.object = serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ def pre_save(self, obj):
+ pass
+
class ListModelMixin(object):
"""
@@ -46,7 +47,8 @@ class ListModelMixin(object):
# which may be `None` to disable pagination.
page_size = self.get_paginate_by(self.object_list)
if page_size:
- paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size)
+ packed = self.paginate_queryset(self.object_list, page_size)
+ paginator, page, queryset, is_paginated = packed
serializer = self.get_pagination_serializer(page)
else:
serializer = self.get_serializer(instance=self.object_list)
@@ -73,26 +75,25 @@ class UpdateModelMixin(object):
def update(self, request, *args, **kwargs):
try:
self.object = self.get_object()
+ success_status = status.HTTP_200_OK
except Http404:
self.object = None
+ success_status = status.HTTP_201_CREATED
serializer = self.get_serializer(data=request.DATA, instance=self.object)
if serializer.is_valid():
- if self.object is None:
- # If PUT occurs to a non existant object, we need to set any
- # attributes on the object that are implicit in the URL.
- self.update_urlconf_attributes(serializer.object)
+ self.pre_save(serializer.object)
self.object = serializer.save()
- return Response(serializer.data)
+ return Response(serializer.data, status=success_status)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- def update_urlconf_attributes(self, obj):
+ def pre_save(self, obj):
"""
- When update (re)creates an object, we need to set any attributes that
- are tied to the URLconf.
+ Set any attributes on the object that are implicit in the request.
"""
+ # pk and/or slug attributes are implicit in the URL.
pk = self.kwargs.get(self.pk_url_kwarg, None)
if pk:
setattr(obj, 'pk', pk)
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index 8b22f669..dae38477 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -1,48 +1,38 @@
+from django.http import Http404
from rest_framework import exceptions
from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import order_by_precedence, media_type_matches
class BaseContentNegotiation(object):
- def negotiate(self, request, renderers, format=None, force=False):
- raise NotImplementedError('.negotiate() must be implemented')
+ def select_parser(self, request, parsers):
+ raise NotImplementedError('.select_parser() must be implemented')
+ def select_renderer(self, request, renderers, format_suffix=None):
+ raise NotImplementedError('.select_renderer() must be implemented')
-class DefaultContentNegotiation(object):
+
+class DefaultContentNegotiation(BaseContentNegotiation):
settings = api_settings
- def select_parser(self, parsers, media_type):
+ def select_parser(self, request, parsers):
"""
Given a list of parsers and a media type, return the appropriate
parser to handle the incoming request.
"""
for parser in parsers:
- if media_type_matches(parser.media_type, media_type):
+ if media_type_matches(parser.media_type, request.content_type):
return parser
return None
- def negotiate(self, request, renderers, format=None, force=False):
+ def select_renderer(self, request, renderers, format_suffix=None):
"""
Given a request and a list of renderers, return a two-tuple of:
(renderer, media type).
-
- If force is set, then suppress exceptions, and forcibly return a
- fallback renderer and media_type.
- """
- try:
- return self.unforced_negotiate(request, renderers, format)
- except (exceptions.InvalidFormat, exceptions.NotAcceptable):
- if force:
- return (renderers[0], renderers[0].media_type)
- raise
-
- def unforced_negotiate(self, request, renderers, format=None):
- """
- As `.negotiate()`, but does not take the optional `force` agument,
- or suppress exceptions.
"""
# Allow URL style format override. eg. "?format=json
- format = format or request.GET.get(self.settings.URL_FORMAT_OVERRIDE)
+ format_query_param = self.settings.URL_FORMAT_OVERRIDE
+ format = format_suffix or request.GET.get(format_query_param)
if format:
renderers = self.filter_renderers(renderers, format)
@@ -77,7 +67,7 @@ class DefaultContentNegotiation(object):
renderers = [renderer for renderer in renderers
if renderer.format == format]
if not renderers:
- raise exceptions.InvalidFormat(format)
+ raise Http404
return renderers
def get_accept_list(self, request):
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 5325a64b..4841676c 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -1,14 +1,8 @@
"""
-Django supports parsing the content of an HTTP request, but only for form POST requests.
-That behavior is sufficient for dealing with standard HTML forms, but it doesn't map well
-to general HTTP requests.
+Parsers are used to parse the content of incoming HTTP requests.
-We need a method to be able to:
-
-1.) Determine the parsed content on a request for methods other than POST (eg typically also PUT)
-
-2.) Determine the parsed content on a request for media types other than application/x-www-form-urlencoded
- and multipart/form-data. (eg also handle multipart/json)
+They give us a generic way of being able to handle various media types
+on the request, such as form content or json encoded data.
"""
from django.http import QueryDict
@@ -21,7 +15,6 @@ from xml.etree import ElementTree as ET
from xml.parsers.expat import ExpatError
import datetime
import decimal
-from io import BytesIO
class DataAndFiles(object):
@@ -33,29 +26,18 @@ class DataAndFiles(object):
class BaseParser(object):
"""
All parsers should extend `BaseParser`, specifying a `media_type`
- attribute, and overriding the `.parse_stream()` method.
+ attribute, and overriding the `.parse()` method.
"""
media_type = None
- def parse(self, string_or_stream, **opts):
- """
- The main entry point to parsers. This is a light wrapper around
- `parse_stream`, that instead handles both string and stream objects.
+ def parse(self, stream, media_type=None, parser_context=None):
"""
- if isinstance(string_or_stream, basestring):
- stream = BytesIO(string_or_stream)
- else:
- stream = string_or_stream
- return self.parse_stream(stream, **opts)
-
- def parse_stream(self, stream, **opts):
- """
- Given a stream to read from, return the deserialized output.
- Should return parsed data, or a DataAndFiles object consisting of the
+ Given a stream to read from, return the parsed representation.
+ Should return parsed data, or a `DataAndFiles` object consisting of the
parsed data and files.
"""
- raise NotImplementedError(".parse_stream() must be overridden.")
+ raise NotImplementedError(".parse() must be overridden.")
class JSONParser(BaseParser):
@@ -65,7 +47,7 @@ class JSONParser(BaseParser):
media_type = 'application/json'
- def parse_stream(self, stream, **opts):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a 2-tuple of `(data, files)`.
@@ -85,7 +67,7 @@ class YAMLParser(BaseParser):
media_type = 'application/yaml'
- def parse_stream(self, stream, **opts):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a 2-tuple of `(data, files)`.
@@ -98,23 +80,6 @@ class YAMLParser(BaseParser):
raise ParseError('YAML parse error - %s' % unicode(exc))
-class PlainTextParser(BaseParser):
- """
- Plain text parser.
- """
-
- media_type = 'text/plain'
-
- def parse_stream(self, stream, **opts):
- """
- Returns a 2-tuple of `(data, files)`.
-
- `data` will simply be a string representing the body of the request.
- `files` will always be `None`.
- """
- return stream.read()
-
-
class FormParser(BaseParser):
"""
Parser for form data.
@@ -122,7 +87,7 @@ class FormParser(BaseParser):
media_type = 'application/x-www-form-urlencoded'
- def parse_stream(self, stream, **opts):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a 2-tuple of `(data, files)`.
@@ -140,15 +105,18 @@ class MultiPartParser(BaseParser):
media_type = 'multipart/form-data'
- def parse_stream(self, stream, **opts):
+ def parse(self, stream, media_type=None, parser_context=None):
"""
Returns a DataAndFiles object.
`.data` will be a `QueryDict` containing all the form parameters.
`.files` will be a `QueryDict` containing all the form files.
"""
- meta = opts['meta']
- upload_handlers = opts['upload_handlers']
+ parser_context = parser_context or {}
+ request = parser_context['request']
+ meta = request.META
+ upload_handlers = request.upload_handlers
+
try:
parser = DjangoMultiPartParser(meta, stream, upload_handlers)
data, files = parser.parse()
@@ -164,7 +132,7 @@ class XMLParser(BaseParser):
media_type = 'application/xml'
- def parse_stream(self, stream, **opts):
+ def parse(self, stream, media_type=None, parser_context=None):
try:
tree = ET.parse(stream)
except (ExpatError, ETParseError, ValueError), exc:
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 13ea39ea..655b78a3 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -1,8 +1,5 @@
"""
-The :mod:`permissions` module bundles a set of permission classes that are used
-for checking if a request passes a certain set of constraints.
-
-Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class.
+Provides a set of pluggable permission policies.
"""
@@ -16,11 +13,22 @@ class BasePermission(object):
def has_permission(self, request, view, obj=None):
"""
- Should simply return, or raise an :exc:`response.ImmediateResponse`.
+ Return `True` if permission is granted, `False` otherwise.
"""
raise NotImplementedError(".has_permission() must be overridden.")
+class AllowAny(BasePermission):
+ """
+ Allow any access.
+ This isn't strictly required, since you could use an empty
+ permission_classes list, but it's useful because it makes the intention
+ more explicit.
+ """
+ def has_permission(self, request, view, obj=None):
+ return True
+
+
class IsAuthenticated(BasePermission):
"""
Allows access only to authenticated users.
@@ -64,7 +72,8 @@ class DjangoModelPermissions(BasePermission):
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the model.
- This permission should only be used on views with a `ModelResource`.
+ This permission will only be applied against view classes that
+ provide a `.model` attribute, such as the generic class-based views.
"""
# Map methods into required permission codes.
@@ -87,12 +96,15 @@ class DjangoModelPermissions(BasePermission):
"""
kwargs = {
'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.module_name
+ 'model_name': model_cls._meta.module_name
}
return [perm % kwargs for perm in self.perms_map[method]]
def has_permission(self, request, view, obj=None):
- model_cls = view.model
+ model_cls = getattr(view, 'model', None)
+ if not model_cls:
+ return True
+
perms = self.get_required_permissions(request.method, model_cls)
if (request.user and
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index e5e4134b..8dff0c77 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -1,12 +1,15 @@
"""
-Renderers are used to serialize a View's output into specific media types.
+Renderers are used to serialize a response into specific media types.
-Django REST framework also provides HTML and PlainText renderers that help self-document the API,
-by serializing the output along with documentation regarding the View, output status and headers,
-and providing forms and links depending on the allowed methods, renderers and parsers on the View.
+They give us a generic way of being able to handle various media types
+on the response, such as JSON encoded data or HTML output.
+
+REST framework also provides an HTML renderer the renders the browseable API.
"""
+import copy
import string
from django import forms
+from django.http.multipartparser import parse_header
from django.template import RequestContext, loader
from django.utils import simplejson as json
from rest_framework.compat import yaml
@@ -16,15 +19,14 @@ from rest_framework.request import clone_request
from rest_framework.utils import dict2xml
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
-from rest_framework.utils.mediatypes import get_media_type_params
from rest_framework import VERSION
from rest_framework import serializers, parsers
class BaseRenderer(object):
"""
- All renderers must extend this class, set the :attr:`media_type` attribute,
- and override the :meth:`render` method.
+ All renderers should extend this class, setting the `media_type`
+ and `format` attributes, and override the `.render()` method.
"""
media_type = None
@@ -58,7 +60,7 @@ class JSONRenderer(BaseRenderer):
if accepted_media_type:
# If the media type looks like 'application/json; indent=4',
# then pretty print the result.
- params = get_media_type_params(accepted_media_type)
+ base_media_type, params = parse_header(accepted_media_type)
indent = params.get('indent', indent)
try:
indent = max(min(int(indent), 8), 0)
@@ -137,13 +139,24 @@ class YAMLRenderer(BaseRenderer):
return yaml.dump(data, stream=None, Dumper=self.encoder)
-class HTMLRenderer(BaseRenderer):
+class TemplateHTMLRenderer(BaseRenderer):
"""
- A Base class provided for convenience.
+ An HTML renderer for use with templates.
+
+ The data supplied to the Response object should be a dictionary that will
+ be used as context for the template.
+
+ The template name is determined by (in order of preference):
+
+ 1. An explicit `.template_name` attribute set on the response.
+ 2. An explicit `.template_name` attribute set on this class.
+ 3. The return result of calling `view.get_template_names()`.
- Render the object simply by using the given template.
- To create a template renderer, subclass this class, and set
- the :attr:`media_type` and :attr:`template` attributes.
+ For example:
+ data = {'users': User.objects.all()}
+ return Response(data, template_name='users.html')
+
+ For pre-rendered HTML, see StaticHTMLRenderer.
"""
media_type = 'text/html'
@@ -186,6 +199,26 @@ class HTMLRenderer(BaseRenderer):
raise ConfigurationError('Returned a template response with no template_name')
+class StaticHTMLRenderer(BaseRenderer):
+ """
+ An HTML renderer class that simply returns pre-rendered HTML.
+
+ The data supplied to the Response object should be a string representing
+ the pre-rendered HTML content.
+
+ For example:
+ data = '<html><body>example</body></html>'
+ return Response(data)
+
+ For template rendered HTML, see TemplateHTMLRenderer.
+ """
+ media_type = 'text/html'
+ format = 'html'
+
+ def render(self, data, accepted_media_type=None, renderer_context=None):
+ return data
+
+
class BrowsableAPIRenderer(BaseRenderer):
"""
HTML renderer used to self-document the API.
@@ -222,11 +255,9 @@ class BrowsableAPIRenderer(BaseRenderer):
return content
- def get_form(self, view, method, request):
+ def show_form_for_method(self, view, method, request, obj):
"""
- Get a form, possibly bound to either the input or output data.
- In the absence on of the Resource having an associated form then
- provide a form that can be used to submit arbitrary content.
+ Returns True if a form should be shown for this method.
"""
if not method in view.allowed_methods:
return # Not a valid method
@@ -235,22 +266,14 @@ class BrowsableAPIRenderer(BaseRenderer):
return # Cannot use form overloading
request = clone_request(request, method)
- if not view.has_permission(request):
- return # Don't have permission
-
- if method == 'DELETE' or method == 'OPTIONS':
- return True # Don't actually need to return a form
-
- if (not getattr(view, 'get_serializer', None) or
- not parsers.FormParser in getattr(view, 'parser_classes')):
- media_types = [parser.media_type for parser in view.parser_classes]
- return self.get_generic_content_form(media_types)
-
- #####
- # TODO: This is a little bit of a hack. Actually we'd like to remove
- # this and just render serializer fields to html directly.
+ try:
+ if not view.has_permission(request, obj):
+ return # Don't have permission
+ except:
+ return # Don't have permission and exception explicitly raise
+ return True
- # We need to map our Fields to Django's Fields.
+ def serializer_to_form_fields(self, serializer):
field_mapping = {
serializers.FloatField: forms.FloatField,
serializers.IntegerField: forms.IntegerField,
@@ -260,32 +283,69 @@ class BrowsableAPIRenderer(BaseRenderer):
serializers.CharField: forms.CharField,
serializers.BooleanField: forms.BooleanField,
serializers.PrimaryKeyRelatedField: forms.ModelChoiceField,
- serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField
+ serializers.ManyPrimaryKeyRelatedField: forms.ModelMultipleChoiceField,
+ serializers.HyperlinkedRelatedField: forms.ModelChoiceField,
+ serializers.ManyHyperlinkedRelatedField: forms.ModelMultipleChoiceField
}
- # Creating an on the fly form see: http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
fields = {}
- obj, data = None, None
- if getattr(view, 'object', None):
- obj = view.object
-
- serializer = view.get_serializer(instance=obj)
for k, v in serializer.get_fields(True).items():
- if getattr(v, 'readonly', True):
+ if getattr(v, 'read_only', True):
continue
kwargs = {}
+ kwargs['required'] = v.required
+
if getattr(v, 'queryset', None):
- kwargs['queryset'] = getattr(v, 'queryset', None)
+ kwargs['queryset'] = v.queryset
+
+ if getattr(v, 'widget', None):
+ widget = copy.deepcopy(v.widget)
+ # If choices have friendly readable names,
+ # then add in the identities too
+ if getattr(widget, 'choices', None):
+ choices = widget.choices
+ if any([ident != desc for (ident, desc) in choices]):
+ choices = [(ident, "%s (%s)" % (desc, ident))
+ for (ident, desc) in choices]
+ widget.choices = choices
+ kwargs['widget'] = widget
+
+ if getattr(v, 'default', None) is not None:
+ kwargs['initial'] = v.default
+
+ kwargs['label'] = k
try:
fields[k] = field_mapping[v.__class__](**kwargs)
except KeyError:
- fields[k] = forms.CharField()
+ fields[k] = forms.CharField(**kwargs)
+ return fields
+
+ def get_form(self, view, method, request):
+ """
+ Get a form, possibly bound to either the input or output data.
+ In the absence on of the Resource having an associated form then
+ provide a form that can be used to submit arbitrary content.
+ """
+ obj = getattr(view, 'object', None)
+ if not self.show_form_for_method(view, method, request, obj):
+ return
+
+ if method == 'DELETE' or method == 'OPTIONS':
+ return True # Don't actually need to return a form
+
+ if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes:
+ media_types = [parser.media_type for parser in view.parser_classes]
+ return self.get_generic_content_form(media_types)
+
+ serializer = view.get_serializer(instance=obj)
+ fields = self.serializer_to_form_fields(serializer)
+ # Creating an on the fly form see:
+ # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
- if obj and not view.request.method == 'DELETE': # Don't fill in the form when the object is deleted
- data = serializer.data
+ data = (obj is not None) and serializer.data or None
form_instance = OnTheFlyForm(data)
return form_instance
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 0a57d376..a1827ba4 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -11,9 +11,18 @@ The wrapped request then offers a richer API, in particular :
"""
from StringIO import StringIO
+from django.http.multipartparser import parse_header
from rest_framework import exceptions
from rest_framework.settings import api_settings
-from rest_framework.utils.mediatypes import is_form_media_type
+
+
+def is_form_media_type(media_type):
+ """
+ Return True if the media type is a valid form media type.
+ """
+ base_media_type, params = parse_header(media_type)
+ return (base_media_type == 'application/x-www-form-urlencoded' or
+ base_media_type == 'multipart/form-data')
class Empty(object):
@@ -35,7 +44,8 @@ def clone_request(request, method):
"""
ret = Request(request._request,
request.parsers,
- request.authenticators)
+ request.authenticators,
+ request.parser_context)
ret._data = request._data
ret._files = request._files
ret._content_type = request._content_type
@@ -65,19 +75,24 @@ class Request(object):
_CONTENTTYPE_PARAM = api_settings.FORM_CONTENTTYPE_OVERRIDE
def __init__(self, request, parsers=None, authenticators=None,
- negotiator=None):
+ negotiator=None, parser_context=None):
self._request = request
self.parsers = parsers or ()
self.authenticators = authenticators or ()
self.negotiator = negotiator or self._default_negotiator()
+ self.parser_context = parser_context
self._data = Empty
self._files = Empty
self._method = Empty
self._content_type = Empty
self._stream = Empty
+ if self.parser_context is None:
+ self.parser_context = {}
+ self.parser_context['request'] = self
+
def _default_negotiator(self):
- return api_settings.DEFAULT_CONTENT_NEGOTIATION()
+ return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
@property
def method(self):
@@ -96,7 +111,7 @@ class Request(object):
"""
Returns the content type header.
- This should be used instead of ``request.META.get('HTTP_CONTENT_TYPE')``,
+ This should be used instead of `request.META.get('HTTP_CONTENT_TYPE')`,
as it allows the content type to be overridden by using a hidden form
field on a form POST request.
"""
@@ -245,16 +260,19 @@ class Request(object):
May raise an `UnsupportedMediaType`, or `ParseError` exception.
"""
- if self.stream is None or self.content_type is None:
+ stream = self.stream
+ media_type = self.content_type
+
+ if stream is None or media_type is None:
return (None, None)
- parser = self.negotiator.select_parser(self.parsers, self.content_type)
+ parser = self.negotiator.select_parser(self, self.parsers)
if not parser:
- raise exceptions.UnsupportedMediaType(self.content_type)
+ raise exceptions.UnsupportedMediaType(media_type)
+
+ parsed = parser.parse(stream, media_type, self.parser_context)
- parsed = parser.parse(self.stream, meta=self.META,
- upload_handlers=self.upload_handlers)
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
try:
diff --git a/rest_framework/resources.py b/rest_framework/resources.py
deleted file mode 100644
index bb3d581f..00000000
--- a/rest_framework/resources.py
+++ /dev/null
@@ -1,95 +0,0 @@
-##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY #####
-
-from functools import update_wrapper
-import inspect
-from django.utils.decorators import classonlymethod
-from rest_framework import views, generics
-
-
-def wrapped(source, dest):
- """
- Copy public, non-method attributes from source to dest, and return dest.
- """
- for attr in [attr for attr in dir(source)
- if not attr.startswith('_') and not inspect.ismethod(attr)]:
- setattr(dest, attr, getattr(source, attr))
- return dest
-
-
-##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY #####
-
-class ResourceMixin(object):
- """
- Clone Django's `View.as_view()` behaviour *except* using REST framework's
- 'method -> action' binding for resources.
- """
-
- @classonlymethod
- def as_view(cls, actions, **initkwargs):
- """
- Main entry point for a request-response process.
- """
- # sanitize keyword arguments
- for key in initkwargs:
- if key in cls.http_method_names:
- raise TypeError("You tried to pass in the %s method name as a "
- "keyword argument to %s(). Don't do that."
- % (key, cls.__name__))
- if not hasattr(cls, key):
- raise TypeError("%s() received an invalid keyword %r" % (
- cls.__name__, key))
-
- def view(request, *args, **kwargs):
- self = cls(**initkwargs)
-
- # Bind methods to actions
- for method, action in actions.items():
- handler = getattr(self, action)
- setattr(self, method, handler)
-
- # As you were, solider.
- if hasattr(self, 'get') and not hasattr(self, 'head'):
- self.head = self.get
- return self.dispatch(request, *args, **kwargs)
-
- # take name and docstring from class
- update_wrapper(view, cls, updated=())
-
- # and possible attributes set by decorators
- # like csrf_exempt from dispatch
- update_wrapper(view, cls.dispatch, assigned=())
- return view
-
-
-##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY #####
-
-class Resource(ResourceMixin, views.APIView):
- pass
-
-
-##### RESOURCES AND ROUTERS ARE NOT YET IMPLEMENTED - PLACEHOLDER ONLY #####
-
-class ModelResource(ResourceMixin, views.APIView):
- root_class = generics.ListCreateAPIView
- detail_class = generics.RetrieveUpdateDestroyAPIView
-
- def root_view(self):
- return wrapped(self, self.root_class())
-
- def detail_view(self):
- return wrapped(self, self.detail_class())
-
- def list(self, request, *args, **kwargs):
- return self.root_view().list(request, args, kwargs)
-
- def create(self, request, *args, **kwargs):
- return self.root_view().create(request, args, kwargs)
-
- def retrieve(self, request, *args, **kwargs):
- return self.detail_view().retrieve(request, args, kwargs)
-
- def update(self, request, *args, **kwargs):
- return self.detail_view().update(request, args, kwargs)
-
- def destroy(self, request, *args, **kwargs):
- return self.detail_view().destroy(request, args, kwargs)
diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py
index ba663f98..c9db02f0 100644
--- a/rest_framework/reverse.py
+++ b/rest_framework/reverse.py
@@ -5,13 +5,15 @@ from django.core.urlresolvers import reverse as django_reverse
from django.utils.functional import lazy
-def reverse(viewname, *args, **kwargs):
+def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
Same as `django.core.urlresolvers.reverse`, but optionally takes a request
and returns a fully qualified URL, using the request to get the base URL.
"""
- request = kwargs.pop('request', None)
- url = django_reverse(viewname, *args, **kwargs)
+ if format is not None:
+ kwargs = kwargs or {}
+ kwargs['format'] = format
+ url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if request:
return request.build_absolute_uri(url)
return url
diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py
index b2438c9b..1bd0a5fc 100755
--- a/rest_framework/runtests/runtests.py
+++ b/rest_framework/runtests/runtests.py
@@ -32,7 +32,7 @@ def main():
else:
print usage()
sys.exit(1)
- failures = test_runner.run_tests(['rest_framework' + test_case])
+ failures = test_runner.run_tests(['tests' + test_case])
sys.exit(failures)
diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py
index 67de82c8..951b1e72 100644
--- a/rest_framework/runtests/settings.py
+++ b/rest_framework/runtests/settings.py
@@ -91,6 +91,7 @@ INSTALLED_APPS = (
# 'django.contrib.admindocs',
'rest_framework',
'rest_framework.authtoken',
+ 'rest_framework.tests'
)
STATIC_URL = '/static/'
@@ -100,14 +101,6 @@ import django
if django.VERSION < (1, 3):
INSTALLED_APPS += ('staticfiles',)
-# OAuth support is optional, so we only test oauth if it's installed.
-try:
- import oauth_provider
-except ImportError:
- pass
-else:
- INSTALLED_APPS += ('oauth_provider',)
-
# If we're running on the Jenkins server we want to archive the coverage reports as XML.
import os
if os.environ.get('HUDSON_URL', None):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 06330017..3d134a74 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -3,6 +3,7 @@ import datetime
import types
from decimal import Decimal
from django.db import models
+from django.forms import widgets
from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model
from rest_framework.fields import *
@@ -22,10 +23,6 @@ class SortedDictWithMetadata(SortedDict, DictWithMetadata):
pass
-class RecursionOccured(BaseException):
- pass
-
-
def _is_protected_type(obj):
"""
True if the object is a native datatype that does not need to
@@ -33,10 +30,10 @@ def _is_protected_type(obj):
"""
return isinstance(obj, (
types.NoneType,
- int, long,
- datetime.datetime, datetime.date, datetime.time,
- float, Decimal,
- basestring)
+ int, long,
+ datetime.datetime, datetime.date, datetime.time,
+ float, Decimal,
+ basestring)
)
@@ -73,7 +70,7 @@ class SerializerOptions(object):
Meta class options for Serializer
"""
def __init__(self, meta):
- self.nested = getattr(meta, 'nested', False)
+ self.depth = getattr(meta, 'depth', 0)
self.fields = getattr(meta, 'fields', ())
self.exclude = getattr(meta, 'exclude', ())
@@ -92,7 +89,6 @@ class BaseSerializer(Field):
self.parent = None
self.root = None
- self.stack = []
self.context = context or {}
self.init_data = data
@@ -151,14 +147,11 @@ class BaseSerializer(Field):
def initialize(self, parent):
"""
Same behaviour as usual Field, except that we need to keep track
- of state so that we can deal with handling maximum depth and recursion.
+ of state so that we can deal with handling maximum depth.
"""
super(BaseSerializer, self).initialize(parent)
- self.stack = parent.stack[:]
- if parent.opts.nested and not isinstance(parent.opts.nested, bool):
- self.opts.nested = parent.opts.nested - 1
- else:
- self.opts.nested = parent.opts.nested
+ if parent.opts.depth:
+ self.opts.depth = parent.opts.depth - 1
#####
# Methods to convert or revert from objects <--> primative representations.
@@ -174,21 +167,13 @@ class BaseSerializer(Field):
Core of serialization.
Convert an object into a dictionary of serialized field values.
"""
- if obj in self.stack and not self.source == '*':
- raise RecursionOccured()
- self.stack.append(obj)
-
ret = self._dict_class()
ret.fields = {}
- fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested)
+ fields = self.get_fields(serialize=True, obj=obj, nested=bool(self.opts.depth))
for field_name, field in fields.items():
key = self.get_field_key(field_name)
- try:
- value = field.field_to_native(obj, field_name)
- except RecursionOccured:
- field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name]
- value = field.field_to_native(obj, field_name)
+ value = field.field_to_native(obj, field_name)
ret[key] = value
ret.fields[key] = field
return ret
@@ -198,7 +183,7 @@ class BaseSerializer(Field):
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
- fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested)
+ fields = self.get_fields(serialize=False, data=data, nested=bool(self.opts.depth))
reverted_data = {}
for field_name, field in fields.items():
try:
@@ -208,6 +193,35 @@ class BaseSerializer(Field):
return reverted_data
+ def perform_validation(self, attrs):
+ """
+ Run `validate_<fieldname>()` and `validate()` methods on the serializer
+ """
+ # TODO: refactor this so we're not determining the fields again
+ fields = self.get_fields(serialize=False, data=attrs, nested=bool(self.opts.depth))
+
+ for field_name, field in fields.items():
+ try:
+ validate_method = getattr(self, 'validate_%s' % field_name, None)
+ if validate_method:
+ source = field.source or field_name
+ attrs = validate_method(attrs, source)
+ except ValidationError as err:
+ self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
+
+ try:
+ attrs = self.validate(attrs)
+ except ValidationError as err:
+ self._errors['non_field_errors'] = err.messages
+
+ return attrs
+
+ def validate(self, attrs):
+ """
+ Stub method, to be overridden in Serializer subclasses
+ """
+ return attrs
+
def restore_object(self, attrs, instance=None):
"""
Deserialize a dictionary of attributes into an object instance.
@@ -241,17 +255,31 @@ class BaseSerializer(Field):
self._errors = {}
if data is not None:
attrs = self.restore_fields(data)
+ attrs = self.perform_validation(attrs)
else:
- self._errors['non_field_errors'] = 'No input provided'
+ self._errors['non_field_errors'] = ['No input provided']
if not self._errors:
return self.restore_object(attrs, instance=getattr(self, 'object', None))
+ def field_to_native(self, obj, field_name):
+ """
+ Override default so that we can apply ModelSerializer as a nested
+ field to relationships.
+ """
+ obj = getattr(obj, self.source or field_name)
+
+ # If the object has an "all" method, assume it's a relationship
+ if is_simple_callable(getattr(obj, 'all', None)):
+ return [self.to_native(item) for item in obj.all()]
+
+ return self.to_native(obj)
+
@property
def errors(self):
"""
Run deserialization and return error data,
- setting self.object if no errors occured.
+ setting self.object if no errors occurred.
"""
if self._errors is None:
obj = self.from_native(self.init_data)
@@ -295,16 +323,6 @@ class ModelSerializer(Serializer):
"""
_options_class = ModelSerializerOptions
- def field_to_native(self, obj, field_name):
- """
- Override default so that we can apply ModelSerializer as a nested
- field to relationships.
- """
- obj = getattr(obj, self.source or field_name)
- if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
- return [self.to_native(item) for item in obj.all()]
- return self.to_native(obj)
-
def default_fields(self, serialize, obj=None, data=None, nested=False):
"""
Return all the fields that should be serialized for the model.
@@ -374,25 +392,43 @@ class ModelSerializer(Serializer):
"""
Creates a default instance of a basic non-relational field.
"""
+ kwargs = {}
+
+ kwargs['blank'] = model_field.blank
+
+ if model_field.null:
+ kwargs['required'] = False
+
+ if model_field.has_default():
+ kwargs['required'] = False
+ kwargs['default'] = model_field.get_default()
+
+ if model_field.__class__ == models.TextField:
+ kwargs['widget'] = widgets.Textarea
+
+ # TODO: TypedChoiceField?
+ if model_field.flatchoices: # This ModelField contains choices
+ kwargs['choices'] = model_field.flatchoices
+ return ChoiceField(**kwargs)
+
field_mapping = {
models.FloatField: FloatField,
models.IntegerField: IntegerField,
+ models.PositiveIntegerField: IntegerField,
+ models.SmallIntegerField: IntegerField,
+ models.PositiveSmallIntegerField: IntegerField,
models.DateTimeField: DateTimeField,
models.DateField: DateField,
models.EmailField: EmailField,
models.CharField: CharField,
+ models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
}
try:
- ret = field_mapping[model_field.__class__]()
+ return field_mapping[model_field.__class__](**kwargs)
except KeyError:
- ret = ModelField(model_field=model_field)
-
- if model_field.default:
- ret.required = False
-
- return ret
+ return ModelField(model_field=model_field, **kwargs)
def restore_object(self, attrs, instance=None):
"""
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 5ebe7ba5..9c40a214 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -3,11 +3,11 @@ Settings for REST framework are all namespaced in the REST_FRAMEWORK setting.
For example your project's `settings.py` file might look like this:
REST_FRAMEWORK = {
- 'DEFAULT_RENDERERS': (
+ 'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.YAMLRenderer',
)
- 'DEFAULT_PARSERS': (
+ 'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
'rest_framework.parsers.YAMLParser',
)
@@ -24,30 +24,36 @@ from django.utils import importlib
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
DEFAULTS = {
- 'DEFAULT_RENDERERS': (
+ 'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.BrowsableAPIRenderer',
),
- 'DEFAULT_PARSERS': (
+ 'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
'rest_framework.parsers.FormParser',
'rest_framework.parsers.MultiPartParser'
),
- 'DEFAULT_AUTHENTICATION': (
+ 'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.SessionAuthentication',
- 'rest_framework.authentication.UserBasicAuthentication'
+ 'rest_framework.authentication.BasicAuthentication'
),
- 'DEFAULT_PERMISSIONS': (),
- 'DEFAULT_THROTTLES': (),
- 'DEFAULT_CONTENT_NEGOTIATION':
+ 'DEFAULT_PERMISSION_CLASSES': (
+ 'rest_framework.permissions.AllowAny',
+ ),
+ 'DEFAULT_THROTTLE_CLASSES': (
+ ),
+
+ 'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
+ 'DEFAULT_MODEL_SERIALIZER_CLASS':
+ 'rest_framework.serializers.ModelSerializer',
+ 'DEFAULT_PAGINATION_SERIALIZER_CLASS':
+ 'rest_framework.pagination.PaginationSerializer',
+
'DEFAULT_THROTTLE_RATES': {
'user': None,
'anon': None,
},
-
- 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
- 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
'PAGINATE_BY': None,
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
@@ -65,14 +71,14 @@ DEFAULTS = {
# List of settings that may be in string import notation.
IMPORT_STRINGS = (
- 'DEFAULT_RENDERERS',
- 'DEFAULT_PARSERS',
- 'DEFAULT_AUTHENTICATION',
- 'DEFAULT_PERMISSIONS',
- 'DEFAULT_THROTTLES',
- 'DEFAULT_CONTENT_NEGOTIATION',
- 'MODEL_SERIALIZER',
- 'PAGINATION_SERIALIZER',
+ 'DEFAULT_RENDERER_CLASSES',
+ 'DEFAULT_PARSER_CLASSES',
+ 'DEFAULT_AUTHENTICATION_CLASSES',
+ 'DEFAULT_PERMISSION_CLASSES',
+ 'DEFAULT_THROTTLE_CLASSES',
+ 'DEFAULT_CONTENT_NEGOTIATION_CLASS',
+ 'DEFAULT_MODEL_SERIALIZER_CLASS',
+ 'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
@@ -111,7 +117,7 @@ class APISettings(object):
For example:
from rest_framework.settings import api_settings
- print api_settings.DEFAULT_RENDERERS
+ print api_settings.DEFAULT_RENDERER_CLASSES
Any setting with string import paths will be automatically resolved
and return the class, rather than the string literal.
diff --git a/rest_framework/static/rest_framework/css/default.css b/rest_framework/static/rest_framework/css/default.css
index 739b9300..e29da395 100644
--- a/rest_framework/static/rest_framework/css/default.css
+++ b/rest_framework/static/rest_framework/css/default.css
@@ -32,6 +32,10 @@ h2, h3 {
margin-right: 1em;
}
+ul.breadcrumb {
+ margin: 58px 0 0 0;
+}
+
/* To allow tooltips to work on disabled elements */
.disabled-tooltip-shield {
position: absolute;
@@ -55,6 +59,7 @@ pre {
.page-header {
border-bottom: none;
padding-bottom: 0px;
+ margin-bottom: 20px;
}
@@ -65,7 +70,7 @@ html{
background: none;
}
-body, .navbar .navbar-inner .container-fluid{
+body, .navbar .navbar-inner .container-fluid {
max-width: 1150px;
margin: 0 auto;
}
@@ -76,13 +81,14 @@ body{
}
#content{
- margin: 40px 0 0 0;
+ margin: 0;
}
/* custom navigation styles */
.wrapper .navbar{
- width:100%;
+ width: 100%;
position: absolute;
- left:0;
+ left: 0;
+ top: 0;
}
.navbar .navbar-inner{
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 5ac6ef67..e0f79481 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -109,7 +109,7 @@
<div class="content-main">
<div class="page-header"><h1>{{ name }}</h1></div>
- <p class="resource-description">{{ description }}</p>
+ {{ description }}
<div class="request-info">
<pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>
diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html
index 65af512e..c1271399 100644
--- a/rest_framework/templates/rest_framework/login.html
+++ b/rest_framework/templates/rest_framework/login.html
@@ -3,42 +3,50 @@
<html>
<head>
- <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/style.css'/>
+ <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap.min.css"/>
+ <link rel="stylesheet" type="text/css" href="{% get_static_prefix %}rest_framework/css/bootstrap-tweaks.css"/>
+ <link rel="stylesheet" type="text/css" href='{% get_static_prefix %}rest_framework/css/default.css'/>
</head>
- <body class="login">
+ <body class="container">
- <div id="container">
-
- <div id="header">
- <div id="branding">
- <h1 id="site-name">Django REST framework</h1>
+<div class="container-fluid" style="margin-top: 30px">
+ <div class="row-fluid">
+
+ <div class="well" style="width: 320px; margin-left: auto; margin-right: auto">
+ <div class="row-fluid">
+ <div>
+ <h3 style="margin: 0 0 20px;">Django REST framework</h3>
</div>
- </div>
+ </div><!-- /row fluid -->
- <div id="content" class="colM">
- <div id="content-main">
- <form method="post" action="{% url 'rest_framework:login' %}" id="login-form">
+ <div class="row-fluid">
+ <div>
+ <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">
{% csrf_token %}
- <div class="form-row">
- <label for="id_username">Username:</label> {{ form.username }}
+ <div id="div_id_username" class="clearfix control-group">
+ <div class="controls" style="height: 30px">
+ <Label class="span4" style="margin-top: 3px">Username:</label>
+ <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">
+ </div>
</div>
- <div class="form-row">
- <label for="id_password">Password:</label> {{ form.password }}
- <input type="hidden" name="next" value="{{ next }}" />
+ <div id="div_id_password" class="clearfix control-group">
+ <div class="controls" style="height: 30px">
+ <Label class="span4" style="margin-top: 3px">Password:</label>
+ <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">
+ </div>
</div>
- <div class="form-row">
- <label>&nbsp;</label><input type="submit" value="Log in">
+ <input type="hidden" name="next" value="{{ next }}" />
+ <div class="form-actions-no-box">
+ <input type="submit" name="submit" value="Log in" class="btn btn-primary" id="submit-id-submit">
</div>
</form>
- <script type="text/javascript">
- document.getElementById('id_username').focus()
- </script>
</div>
- <br class="clear">
- </div>
+ </div><!-- /row fluid -->
+ </div><!--/span-->
- <div id="footer"></div>
+ </div><!-- /.row-fluid -->
+ </div>
</div>
</body>
diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py
index adeaf6da..e69de29b 100644
--- a/rest_framework/tests/__init__.py
+++ b/rest_framework/tests/__init__.py
@@ -1,13 +0,0 @@
-"""
-Force import of all modules in this package in order to get the standard test
-runner to pick up the tests. Yowzers.
-"""
-import os
-
-modules = [filename.rsplit('.', 1)[0]
- for filename in os.listdir(os.path.dirname(__file__))
- if filename.endswith('.py') and not filename.startswith('_')]
-__test__ = dict()
-
-for module in modules:
- exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f4263478..a8279ef2 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -2,7 +2,7 @@ from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import simplejson as json
from rest_framework import generics, serializers, status
-from rest_framework.tests.models import BasicModel, Comment
+from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
factory = RequestFactory()
@@ -22,6 +22,22 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView):
model = BasicModel
+class SlugSerializer(serializers.ModelSerializer):
+ slug = serializers.Field() # read only
+
+ class Meta:
+ model = SlugBasedModel
+ exclude = ('id',)
+
+
+class SlugBasedInstanceView(InstanceView):
+ """
+ A model with a slug-field.
+ """
+ model = SlugBasedModel
+ serializer_class = SlugSerializer
+
+
class TestRootView(TestCase):
def setUp(self):
"""
@@ -129,6 +145,7 @@ class TestInstanceView(TestCase):
for obj in self.objects.all()
]
self.view = InstanceView.as_view()
+ self.slug_based_view = SlugBasedInstanceView.as_view()
def test_get_instance_view(self):
"""
@@ -198,7 +215,7 @@ class TestInstanceView(TestCase):
def test_put_cannot_set_id(self):
"""
- POST requests to create a new object should not be able to set the id.
+ PUT requests to create a new object should not be able to set the id.
"""
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
@@ -219,11 +236,39 @@ class TestInstanceView(TestCase):
request = factory.put('/1', json.dumps(content),
content_type='application/json')
response = self.view(request, pk=1).render()
- self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
self.assertEquals(updated.text, 'foobar')
+ def test_put_as_create_on_id_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if it doesn't exist.
+ """
+ content = {'text': 'foobar'}
+ # pk fields can not be created on demand, only the database can set th pk for a new object
+ request = factory.put('/5', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=5).render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ new_obj = self.objects.get(pk=5)
+ self.assertEquals(new_obj.text, 'foobar')
+
+ def test_put_as_create_on_slug_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
+ """
+ content = {'text': 'foobar'}
+ request = factory.put('/test_slug', json.dumps(content),
+ content_type='application/json')
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ new_obj = SlugBasedModel.objects.get(slug='test_slug')
+ self.assertEquals(new_obj.text, 'foobar')
+
# Regression test for #285
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py
index da2f83c3..10d7e31d 100644
--- a/rest_framework/tests/htmlrenderer.py
+++ b/rest_framework/tests/htmlrenderer.py
@@ -3,12 +3,12 @@ from django.test import TestCase
from django.template import TemplateDoesNotExist, Template
import django.template.loader
from rest_framework.decorators import api_view, renderer_classes
-from rest_framework.renderers import HTMLRenderer
+from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
@api_view(('GET',))
-@renderer_classes((HTMLRenderer,))
+@renderer_classes((TemplateHTMLRenderer,))
def example(request):
"""
A view that can returns an HTML representation.
@@ -22,7 +22,7 @@ urlpatterns = patterns('',
)
-class HTMLRendererTests(TestCase):
+class TemplateHTMLRendererTests(TestCase):
urls = 'rest_framework.tests.htmlrenderer'
def setUp(self):
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 5532a8ee..92c3691e 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -2,11 +2,19 @@ from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
-from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel
+from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment
factory = RequestFactory()
+class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+ blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail', queryset=BlogPost.objects.all())
+
+ def restore_object(self, attrs, instance=None):
+ return BlogPostComment(**attrs)
+
+
class BasicList(generics.ListCreateAPIView):
model = BasicModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -32,12 +40,22 @@ class ManyToManyDetail(generics.RetrieveAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer
+class BlogPostCommentListCreate(generics.ListCreateAPIView):
+ model = BlogPostComment
+ model_serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostDetail(generics.RetrieveAPIView):
+ model = BlogPost
+
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
+ url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
+ url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list')
)
@@ -124,3 +142,27 @@ class TestManyToManyHyperlinkedView(TestCase):
response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
+
+
+class TestCreateWithForeignKeys(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create a blog post
+ """
+ self.post = BlogPost.objects.create(title="Test post")
+ self.create_view = BlogPostCommentListCreate.as_view()
+
+ def test_create_comment(self):
+
+ data = {
+ 'text': 'A test comment',
+ 'blog_post_url': 'http://testserver/posts/1/'
+ }
+
+ request = factory.post('/comments/', data=data)
+ response = self.create_view(request).render()
+ self.assertEqual(response.status_code, 201)
+ self.assertEqual(self.post.blogpostcomment_set.count(), 1)
+ self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 780c9dba..9efedbc4 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -40,7 +40,7 @@ class RESTFrameworkModel(models.Model):
Base for test models that sets app_label, so they play nicely.
"""
class Meta:
- app_label = 'rest_framework'
+ app_label = 'tests'
abstract = True
@@ -52,6 +52,11 @@ class BasicModel(RESTFrameworkModel):
text = models.CharField(max_length=100)
+class SlugBasedModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ slug = models.SlugField(max_length=32)
+
+
class DefaultValueModel(RESTFrameworkModel):
text = models.CharField(default='foobar', max_length=100)
@@ -63,6 +68,11 @@ class CallableDefaultValueModel(RESTFrameworkModel):
class ManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
+
+class ReadOnlyManyToManyModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100, default='anchor')
+ rel = models.ManyToManyField(Anchor)
+
# Models to test generic relations
@@ -98,3 +108,28 @@ class Comment(RESTFrameworkModel):
email = models.EmailField()
content = models.CharField(max_length=200)
created = models.DateTimeField(auto_now_add=True)
+
+
+class ActionItem(RESTFrameworkModel):
+ title = models.CharField(max_length=200)
+ done = models.BooleanField(default=False)
+
+
+# Models for reverse relations
+class BlogPost(RESTFrameworkModel):
+ title = models.CharField(max_length=100)
+
+
+class BlogPostComment(RESTFrameworkModel):
+ text = models.TextField()
+ blog_post = models.ForeignKey(BlogPost)
+
+
+class Person(RESTFrameworkModel):
+ name = models.CharField(max_length=10)
+ age = models.IntegerField(null=True, blank=True)
+
+
+# Model for issue #324
+class BlankFieldModel(RESTFrameworkModel):
+ title = models.CharField(max_length=100, blank=True)
diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py
index d8265b43..e06354ea 100644
--- a/rest_framework/tests/negotiation.py
+++ b/rest_framework/tests/negotiation.py
@@ -18,20 +18,20 @@ class TestAcceptedMediaType(TestCase):
self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
self.negotiator = DefaultContentNegotiation()
- def negotiate(self, request):
- return self.negotiator.negotiate(request, self.renderers)
+ def select_renderer(self, request):
+ return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self):
request = factory.get('/')
- accepted_renderer, accepted_media_type = self.negotiate(request)
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json')
def test_client_underspecifies_accept_use_renderer(self):
request = factory.get('/', HTTP_ACCEPT='*/*')
- accepted_renderer, accepted_media_type = self.negotiate(request)
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json')
def test_client_overspecifies_accept_use_client(self):
request = factory.get('/', HTTP_ACCEPT='application/json; indent=8')
- accepted_renderer, accepted_media_type = self.negotiate(request)
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
self.assertEquals(accepted_media_type, 'application/json; indent=8')
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
index 7b24b036..ff48f3fa 100644
--- a/rest_framework/tests/request.py
+++ b/rest_framework/tests/request.py
@@ -10,9 +10,9 @@ from rest_framework import status
from rest_framework.authentication import SessionAuthentication
from django.test.client import RequestFactory
from rest_framework.parsers import (
+ BaseParser,
FormParser,
MultiPartParser,
- PlainTextParser,
JSONParser
)
from rest_framework.request import Request
@@ -24,6 +24,19 @@ from rest_framework.views import APIView
factory = RequestFactory()
+class PlainTextParser(BaseParser):
+ media_type = 'text/plain'
+
+ def parse(self, stream, media_type=None, parser_context=None):
+ """
+ Returns a 2-tuple of `(data, files)`.
+
+ `data` will simply be a string representing the body of the request.
+ `files` will always be `None`.
+ """
+ return stream.read()
+
+
class TestMethodOverloading(TestCase):
def test_method(self):
"""
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 256987ad..d4b43862 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -4,6 +4,11 @@ from rest_framework import serializers
from rest_framework.tests.models import *
+class SubComment(object):
+ def __init__(self, sub_comment):
+ self.sub_comment = sub_comment
+
+
class Comment(object):
def __init__(self, email, content, created):
self.email = email
@@ -14,11 +19,16 @@ class Comment(object):
return all([getattr(self, attr) == getattr(other, attr)
for attr in ('email', 'content', 'created')])
+ def get_sub_comment(self):
+ sub_comment = SubComment('And Merry Christmas!')
+ return sub_comment
+
class CommentSerializer(serializers.Serializer):
email = serializers.EmailField()
content = serializers.CharField(max_length=1000)
created = serializers.DateTimeField()
+ sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
def restore_object(self, data, instance=None):
if instance is None:
@@ -28,6 +38,16 @@ class CommentSerializer(serializers.Serializer):
return instance
+class ActionItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ActionItem
+
+
+class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+
+
class BasicTests(TestCase):
def setUp(self):
self.comment = Comment(
@@ -38,7 +58,14 @@ class BasicTests(TestCase):
self.data = {
'email': 'tom@example.com',
'content': 'Happy new year!',
- 'created': datetime.datetime(2012, 1, 1)
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'This wont change'
+ }
+ self.expected = {
+ 'email': 'tom@example.com',
+ 'content': 'Happy new year!',
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'And Merry Christmas!'
}
def test_empty(self):
@@ -46,14 +73,14 @@ class BasicTests(TestCase):
expected = {
'email': '',
'content': '',
- 'created': None
+ 'created': None,
+ 'sub_comment': ''
}
self.assertEquals(serializer.data, expected)
def test_retrieve(self):
serializer = CommentSerializer(instance=self.comment)
- expected = self.data
- self.assertEquals(serializer.data, expected)
+ self.assertEquals(serializer.data, self.expected)
def test_create(self):
serializer = CommentSerializer(self.data)
@@ -61,6 +88,7 @@ class BasicTests(TestCase):
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected)
self.assertFalse(serializer.object is expected)
+ self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
def test_update(self):
serializer = CommentSerializer(self.data, instance=self.comment)
@@ -68,6 +96,7 @@ class BasicTests(TestCase):
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected)
self.assertTrue(serializer.object is expected)
+ self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
class ValidationTests(TestCase):
@@ -82,6 +111,8 @@ class ValidationTests(TestCase):
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
}
+ self.actionitem = ActionItem('Some to do item',
+ )
def test_create(self):
serializer = CommentSerializer(self.data)
@@ -102,6 +133,74 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'email': [u'This field is required.']})
+ def test_missing_bool_with_default(self):
+ """Make sure that a boolean value with a 'False' value is not
+ mistaken for not having a default."""
+ data = {
+ 'title': 'Some action item',
+ #No 'done' value.
+ }
+ serializer = ActionItemSerializer(data, instance=self.actionitem)
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.errors, {})
+
+ def test_field_validation(self):
+
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = CommentSerializerWithFieldValidator(data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = CommentSerializerWithFieldValidator(data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+
+ def test_cross_field_validation(self):
+
+ class CommentSerializerWithCrossFieldValidator(CommentSerializer):
+
+ def validate(self, attrs):
+ if attrs["email"] not in attrs["content"]:
+ raise serializers.ValidationError("Email address not in content")
+ return attrs
+
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A comment from tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = CommentSerializerWithCrossFieldValidator(data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'A comment from foo@bar.com'
+
+ serializer = CommentSerializerWithCrossFieldValidator(data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']})
+
+ def test_null_is_true_fields(self):
+ """
+ Omitting a value for null-field should validate.
+ """
+ serializer = PersonSerializer({'name': 'marko'})
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.errors, {})
+
class MetadataTests(TestCase):
def test_empty(self):
@@ -212,6 +311,61 @@ class ManyToManyTests(TestCase):
self.assertEquals(list(instance.rel.all()), [])
+class ReadOnlyManyToManyTests(TestCase):
+ def setUp(self):
+ class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
+ rel = serializers.ManyRelatedField(read_only=True)
+
+ class Meta:
+ model = ReadOnlyManyToManyModel
+
+ self.serializer_class = ReadOnlyManyToManySerializer
+
+ # An anchor instance to use for the relationship
+ self.anchor = Anchor()
+ self.anchor.save()
+
+ # A model instance with a many to many relationship to the anchor
+ self.instance = ReadOnlyManyToManyModel()
+ self.instance.save()
+ self.instance.rel.add(self.anchor)
+
+ # A serialized representation of the model instance
+ self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
+
+ def test_update(self):
+ """
+ Attempt to update an instance of a model with a ManyToMany
+ relationship. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': [self.anchor.id, new_anchor.id]}
+ serializer = self.serializer_class(data, instance=self.instance)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEquals(list(instance.rel.all()), [self.anchor])
+
+ def test_update_without_relationship(self):
+ """
+ Attempt to update an instance of a model where many to ManyToMany
+ relationship is not supplied. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {}
+ serializer = self.serializer_class(data, instance=self.instance)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEquals(list(instance.rel.all()), [self.anchor])
+
+
class DefaultValueTests(TestCase):
def setUp(self):
class DefaultValueSerializer(serializers.ModelSerializer):
@@ -266,3 +420,81 @@ class CallableDefaultValueTests(TestCase):
self.assertEquals(len(self.objects.all()), 1)
self.assertEquals(instance.pk, 1)
self.assertEquals(instance.text, 'overridden')
+
+
+class ManyRelatedTests(TestCase):
+ def setUp(self):
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ self.serializer_class = BlogPostSerializer
+
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ serializer = self.serializer_class(instance=post)
+ expected = {
+ 'title': 'Test blog post',
+ 'comments': [
+ {'text': 'I hate this blog post'},
+ {'text': 'I love this blog post'}
+ ]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+# Test for issue #324
+class BlankFieldTests(TestCase):
+ def setUp(self):
+
+ class BlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlankFieldModel
+
+ class BlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField(blank=True)
+
+ class NotBlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+ class NotBlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField()
+
+ self.model_serializer_class = BlankFieldModelSerializer
+ self.serializer_class = BlankFieldSerializer
+ self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
+ self.not_blank_serializer_class = NotBlankFieldSerializer
+ self.data = {'title': ''}
+
+ def test_create_blank_field(self):
+ serializer = self.serializer_class(self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_model_blank_field(self):
+ serializer = self.model_serializer_class(self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a non-model serializer
+ """
+ serializer = self.not_blank_serializer_class(self.data)
+ self.assertEquals(serializer.is_valid(), False)
+
+ def test_create_model_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a model serializer
+ """
+ serializer = self.not_blank_model_serializer_class(self.data)
+ self.assertEquals(serializer.is_valid(), False)
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
new file mode 100644
index 00000000..adeaf6da
--- /dev/null
+++ b/rest_framework/tests/tests.py
@@ -0,0 +1,13 @@
+"""
+Force import of all modules in this package in order to get the standard test
+runner to pick up the tests. Yowzers.
+"""
+import os
+
+modules = [filename.rsplit('.', 1)[0]
+ for filename in os.listdir(os.path.dirname(__file__))
+ if filename.endswith('.py') and not filename.startswith('_')]
+__test__ = dict()
+
+for module in modules:
+ exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py
index b390c42f..c032985e 100644
--- a/rest_framework/tests/validators.py
+++ b/rest_framework/tests/validators.py
@@ -285,7 +285,7 @@
# uiop = models.CharField(max_length=256, blank=True)
# @property
-# def readonly(self):
+# def read_only(self):
# return 'read only'
# class MockResource(ModelResource):
@@ -298,7 +298,7 @@
# def test_property_fields_are_allowed_on_model_forms(self):
# """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
-# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'}
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
# self.assertEqual(self.validator.validate_request(content, None), content)
# def test_property_fields_are_not_required_on_model_forms(self):
@@ -310,19 +310,19 @@
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'}
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
# def test_validate_requires_fields_on_model_forms(self):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)"""
-# content = {'readonly': 'read only'}
+# content = {'read_only': 'read only'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
# def test_validate_does_not_require_blankable_fields_on_model_forms(self):
# """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
-# content = {'qwerty': 'example', 'readonly': 'read only'}
+# content = {'qwerty': 'example', 'read_only': 'read only'}
# self.validator.validate_request(content, None)
# def test_model_form_validator_uses_model_forms(self):
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 566c277d..8fe64248 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -16,13 +16,13 @@ class BaseThrottle(object):
def wait(self):
"""
- Optionally, return a recommeded number of seconds to wait before
+ Optionally, return a recommended number of seconds to wait before
the next request.
"""
return None
-class SimpleRateThottle(BaseThrottle):
+class SimpleRateThrottle(BaseThrottle):
"""
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
@@ -60,7 +60,7 @@ class SimpleRateThottle(BaseThrottle):
Determine the string representation of the allowed request rate.
"""
if not getattr(self, 'scope', None):
- msg = ("You must set either `.scope` or `.rate` for '%s' thottle" %
+ msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise exceptions.ConfigurationError(msg)
@@ -133,11 +133,11 @@ class SimpleRateThottle(BaseThrottle):
return remaining_duration / float(available_requests)
-class AnonRateThrottle(SimpleRateThottle):
+class AnonRateThrottle(SimpleRateThrottle):
"""
Limits the rate of API calls that may be made by a anonymous users.
- The IP address of the request will be used as the unqiue cache key.
+ The IP address of the request will be used as the unique cache key.
"""
scope = 'anon'
@@ -153,7 +153,7 @@ class AnonRateThrottle(SimpleRateThottle):
}
-class UserRateThrottle(SimpleRateThottle):
+class UserRateThrottle(SimpleRateThrottle):
"""
Limits the rate of API calls that may be made by a given user.
@@ -175,7 +175,7 @@ class UserRateThrottle(SimpleRateThottle):
}
-class ScopedRateThrottle(SimpleRateThottle):
+class ScopedRateThrottle(SimpleRateThrottle):
"""
Limits the rate of API calls by different amounts for various parts of
the API. Any view that has the `throttle_scope` property set will be
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index 386c78a2..316ccd19 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -2,26 +2,23 @@ from django.conf.urls.defaults import url
from rest_framework.settings import api_settings
-def format_suffix_patterns(urlpatterns, suffix_required=False,
- suffix_kwarg=None, allowed=None):
+def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
"""
Supplement existing urlpatterns with corrosponding patterns that also
include a '.format' suffix. Retains urlpattern ordering.
+ urlpatterns:
+ A list of URL patterns.
+
suffix_required:
If `True`, only suffixed URLs will be generated, and non-suffixed
URLs will not be used. Defaults to `False`.
- suffix_kwarg:
- The name of the kwarg that will be passed to the view.
- Defaults to 'format'.
-
allowed:
An optional tuple/list of allowed suffixes. eg ['json', 'api']
Defaults to `None`, which allows any suffix.
-
"""
- suffix_kwarg = suffix_kwarg or api_settings.FORMAT_SUFFIX_KWARG
+ suffix_kwarg = api_settings.FORMAT_SUFFIX_KWARG
if allowed:
if len(allowed) == 1:
allowed_pattern = allowed[0]
diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py
index 5eba7fb2..ee7f3a54 100644
--- a/rest_framework/utils/mediatypes.py
+++ b/rest_framework/utils/mediatypes.py
@@ -25,32 +25,6 @@ def media_type_matches(lhs, rhs):
return lhs.match(rhs)
-def is_form_media_type(media_type):
- """
- Return True if the media type is a valid form media type as defined by the HTML4 spec.
- (NB. HTML5 also adds text/plain to the list of valid form media types, but we don't support this here)
- """
- media_type = _MediaType(media_type)
- return media_type.full_type == 'application/x-www-form-urlencoded' or \
- media_type.full_type == 'multipart/form-data'
-
-
-def add_media_type_param(media_type, key, val):
- """
- Add a key, value parameter to a media type string, and return the new media type string.
- """
- media_type = _MediaType(media_type)
- media_type.params[key] = val
- return str(media_type)
-
-
-def get_media_type_params(media_type):
- """
- Return a dictionary of the parameters on the given media type.
- """
- return _MediaType(media_type).params
-
-
def order_by_precedence(media_type_lst):
"""
Returns a list of sets of media type strings, ordered by precedence.
diff --git a/rest_framework/views.py b/rest_framework/views.py
index b3f36085..71e1fe6c 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -1,8 +1,5 @@
"""
-The :mod:`views` module provides the Views you will most probably
-be subclassing in your implementation.
-
-By setting or modifying class attributes on your view, you change it's predefined behaviour.
+Provides an APIView class that is used as the base of all class-based views.
"""
import re
@@ -57,12 +54,12 @@ def _camelcase_to_spaces(content):
class APIView(View):
settings = api_settings
- renderer_classes = api_settings.DEFAULT_RENDERERS
- parser_classes = api_settings.DEFAULT_PARSERS
- authentication_classes = api_settings.DEFAULT_AUTHENTICATION
- throttle_classes = api_settings.DEFAULT_THROTTLES
- permission_classes = api_settings.DEFAULT_PERMISSIONS
- content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION
+ renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
+ parser_classes = api_settings.DEFAULT_PARSER_CLASSES
+ authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
+ throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
+ permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
+ content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS
@classmethod
def as_view(cls, **initkwargs):
@@ -159,18 +156,31 @@ class APIView(View):
"""
raise exceptions.Throttled(wait)
+ def get_parser_context(self, http_request):
+ """
+ Returns a dict that is passed through to Parser.parse(),
+ as the `parser_context` keyword argument.
+ """
+ # Note: Additionally `request` will also be added to the context
+ # by the Request object.
+ return {
+ 'view': self,
+ 'args': getattr(self, 'args', ()),
+ 'kwargs': getattr(self, 'kwargs', {})
+ }
+
def get_renderer_context(self):
"""
- Returns a dict that is passed through to the Renderer.render(),
+ Returns a dict that is passed through to Renderer.render(),
as the `renderer_context` keyword argument.
"""
- # Note: Additionally 'response' will also be set on the context,
+ # Note: Additionally 'response' will also be added to the context,
# by the Response object.
return {
'view': self,
- 'request': self.request,
- 'args': self.args,
- 'kwargs': self.kwargs
+ 'args': getattr(self, 'args', ()),
+ 'kwargs': getattr(self, 'kwargs', {}),
+ 'request': getattr(self, 'request', None)
}
# API policy instantiation methods
@@ -208,7 +218,7 @@ class APIView(View):
def get_throttles(self):
"""
- Instantiates and returns the list of thottles that this view uses.
+ Instantiates and returns the list of throttles that this view uses.
"""
return [throttle() for throttle in self.throttle_classes]
@@ -228,7 +238,13 @@ class APIView(View):
"""
renderers = self.get_renderers()
conneg = self.get_content_negotiator()
- return conneg.negotiate(request, renderers, self.format_kwarg, force)
+
+ try:
+ return conneg.select_renderer(request, renderers, self.format_kwarg)
+ except:
+ if force:
+ return (renderers[0], renderers[0].media_type)
+ raise
def has_permission(self, request, obj=None):
"""
@@ -253,10 +269,13 @@ class APIView(View):
"""
Returns the initial request object.
"""
+ parser_context = self.get_parser_context(request)
+
return Request(request,
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
- negotiator=self.get_content_negotiator())
+ negotiator=self.get_content_negotiator(),
+ parser_context=parser_context)
def initial(self, request, *args, **kwargs):
"""