aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py2
-rw-r--r--rest_framework/compat.py20
-rw-r--r--rest_framework/decorators.py30
-rw-r--r--rest_framework/fields.py79
-rw-r--r--rest_framework/filters.py4
-rw-r--r--rest_framework/generics.py348
-rw-r--r--rest_framework/mixins.py34
-rw-r--r--rest_framework/negotiation.py4
-rw-r--r--rest_framework/pagination.py6
-rw-r--r--rest_framework/permissions.py22
-rw-r--r--rest_framework/relations.py195
-rw-r--r--rest_framework/renderers.py16
-rw-r--r--rest_framework/request.py5
-rw-r--r--rest_framework/response.py6
-rw-r--r--rest_framework/routers.py246
-rw-r--r--rest_framework/serializers.py145
-rw-r--r--rest_framework/tests/description.py63
-rw-r--r--rest_framework/tests/fields.py165
-rw-r--r--rest_framework/tests/filterset.py6
-rw-r--r--rest_framework/tests/relations_hyperlink.py16
-rw-r--r--rest_framework/tests/relations_nested.py24
-rw-r--r--rest_framework/tests/relations_pk.py17
-rw-r--r--rest_framework/tests/serializer.py42
-rw-r--r--rest_framework/tests/serializer_nested.py4
-rw-r--r--rest_framework/throttling.py8
-rw-r--r--rest_framework/utils/breadcrumbs.py34
-rw-r--r--rest_framework/utils/formatting.py80
-rw-r--r--rest_framework/views.py99
-rw-r--r--rest_framework/viewsets.py132
29 files changed, 1444 insertions, 408 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 1eebb5b9..9caca788 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -1,5 +1,5 @@
"""
-Provides a set of pluggable authentication policies.
+Provides various authentication policies.
"""
from __future__ import unicode_literals
import base64
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index f8e4e7ca..cd39f544 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -88,9 +88,7 @@ else:
raise ImportError("User model is not to be found.")
-# First implementation of Django class-based views did not include head method
-# in base View class - https://code.djangoproject.com/ticket/15668
-if django.VERSION >= (1, 4):
+if django.VERSION >= (1, 5):
from django.views.generic import View
else:
from django.views.generic import View as _View
@@ -98,6 +96,8 @@ else:
from django.utils.functional import update_wrapper
class View(_View):
+ # 1.3 does not include head method in base View class
+ # See: https://code.djangoproject.com/ticket/15668
@classonlymethod
def as_view(cls, **initkwargs):
"""
@@ -127,11 +127,15 @@ else:
update_wrapper(view, cls.dispatch, assigned=())
return view
-# Taken from @markotibold's attempt at supporting PATCH.
-# https://github.com/markotibold/django-rest-framework/tree/patch
-http_method_names = set(View.http_method_names)
-http_method_names.add('patch')
-View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django
+ # _allowed_methods only present from 1.5 onwards
+ def _allowed_methods(self):
+ return [m.upper() for m in self.http_method_names if hasattr(self, m)]
+
+
+# PATCH method is not implemented by Django
+if 'patch' not in View.http_method_names:
+ View.http_method_names = View.http_method_names + ['patch']
+
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4):
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 8250cd3b..81e585e1 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -1,3 +1,11 @@
+"""
+The most imporant decorator in this module is `@api_view`, which is used
+for writing function-based views with REST framework.
+
+There are also various decorators for setting the API policies on function
+based views, as well as the `@action` and `@link` decorators, which are
+used to annotate methods on viewsets that should be included by routers.
+"""
from __future__ import unicode_literals
from rest_framework.compat import six
from rest_framework.views import APIView
@@ -97,3 +105,25 @@ def permission_classes(permission_classes):
func.permission_classes = permission_classes
return func
return decorator
+
+
+def link(**kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for GET requests.
+ """
+ def decorator(func):
+ func.bind_to_method = 'get'
+ func.kwargs = kwargs
+ return func
+ return decorator
+
+
+def action(**kwargs):
+ """
+ Used to mark a method on a ViewSet that should be routed for POST requests.
+ """
+ def decorator(func):
+ func.bind_to_method = 'post'
+ func.kwargs = kwargs
+ return func
+ return decorator
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index f3496b53..f934fc39 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,7 +1,13 @@
+"""
+Serializer fields perform validation on incoming data.
+
+They are very similar to Django's form fields.
+"""
from __future__ import unicode_literals
import copy
import datetime
+from decimal import Decimal, DecimalException
import inspect
import re
import warnings
@@ -194,9 +200,9 @@ class WritableField(Field):
# 'blank' is to be deprecated in favor of 'required'
if blank is not None:
- warnings.warn('The `blank` keyword argument is due to deprecated. '
+ warnings.warn('The `blank` keyword argument is deprecated. '
'Use the `required` keyword argument instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
required = not(blank)
super(WritableField, self).__init__(source=source)
@@ -721,6 +727,75 @@ class FloatField(WritableField):
raise ValidationError(msg)
+class DecimalField(WritableField):
+ type_name = 'DecimalField'
+ form_field_class = forms.DecimalField
+
+ default_error_messages = {
+ 'invalid': _('Enter a number.'),
+ 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
+ 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
+ 'max_digits': _('Ensure that there are no more than %s digits in total.'),
+ 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
+ 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
+ }
+
+ def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
+ self.max_value, self.min_value = max_value, min_value
+ self.max_digits, self.decimal_places = max_digits, decimal_places
+ super(DecimalField, self).__init__(*args, **kwargs)
+
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def from_native(self, value):
+ """
+ Validates that the input is a decimal number. Returns a Decimal
+ instance. Returns None for empty values. Ensures that there are no more
+ than max_digits in the number, and no more than decimal_places digits
+ after the decimal point.
+ """
+ if value in validators.EMPTY_VALUES:
+ return None
+ value = smart_text(value).strip()
+ try:
+ value = Decimal(value)
+ except DecimalException:
+ raise ValidationError(self.error_messages['invalid'])
+ return value
+
+ def validate(self, value):
+ super(DecimalField, self).validate(value)
+ if value in validators.EMPTY_VALUES:
+ return
+ # Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
+ # since it is never equal to itself. However, NaN is the only value that
+ # isn't equal to itself, so we can use this to identify NaN
+ if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
+ raise ValidationError(self.error_messages['invalid'])
+ sign, digittuple, exponent = value.as_tuple()
+ decimals = abs(exponent)
+ # digittuple doesn't include any leading zeros.
+ digits = len(digittuple)
+ if decimals > digits:
+ # We have leading zeros up to or past the decimal point. Count
+ # everything past the decimal point as a digit. We do not count
+ # 0 before the decimal point as a digit since that would mean
+ # we would not allow max_digits = decimal_places.
+ digits = decimals
+ whole_digits = digits - decimals
+
+ if self.max_digits is not None and digits > self.max_digits:
+ raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
+ if self.decimal_places is not None and decimals > self.decimal_places:
+ raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
+ if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
+ raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
+ return value
+
+
class FileField(WritableField):
use_files = True
type_name = 'FileField'
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 413fa0d2..5e1cdbac 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -1,3 +1,7 @@
+"""
+Provides generic filtering backends that can be used to filter the results
+returned by list views.
+"""
from __future__ import unicode_literals
from rest_framework.compat import django_filters
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index f9133c73..62129dcc 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -2,32 +2,60 @@
Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
+
+from django.core.exceptions import ImproperlyConfigured
+from django.core.paginator import Paginator, InvalidPage
+from django.http import Http404
+from django.shortcuts import get_object_or_404
+from django.utils.translation import ugettext as _
from rest_framework import views, mixins
+from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings
-from django.views.generic.detail import SingleObjectMixin
-from django.views.generic.list import MultipleObjectMixin
-
+import warnings
-### Base classes for the generic views ###
class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
- model = None
+ # You'll need to either set these attributes,
+ # or override `get_queryset()`/`get_serializer_class()`.
+ queryset = None
serializer_class = None
- model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
+
+ # This shortcut may be used instead of setting either or both
+ # of the `queryset`/`serializer_class` attributes, although using
+ # the explicit style is generally preferred.
+ model = None
+
+ # If you want to use object lookups other than pk, set this attribute.
+ # For more complex lookup requirements override `get_object()`.
+ lookup_field = 'pk'
+
+ # Pagination settings
+ paginate_by = api_settings.PAGINATE_BY
+ paginate_by_param = api_settings.PAGINATE_BY_PARAM
+ pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
+ page_kwarg = 'page'
+
+ # The filter backend class to use for queryset filtering
filter_backend = api_settings.FILTER_BACKEND
- def filter_queryset(self, queryset):
- """
- Given a queryset, filter it with whichever filter backend is in use.
- """
- if not self.filter_backend:
- return queryset
- backend = self.filter_backend()
- return backend.filter_queryset(self.request, queryset, self)
+ # Determines if the view will return 200 or 404 responses for empty lists.
+ allow_empty = True
+
+ # The following attributes may be subject to change,
+ # and should be considered private API.
+ model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
+ paginator_class = Paginator
+
+ ######################################
+ # These are pending deprecation...
+
+ pk_url_kwarg = 'pk'
+ slug_url_kwarg = 'slug'
+ slug_field = 'slug'
def get_serializer_context(self):
"""
@@ -39,24 +67,6 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer_class(self):
- """
- Return the class to use for the serializer.
-
- Defaults to using `self.serializer_class`, falls back to constructing a
- model serializer class using `self.model_serializer_class`, with
- `self.model` as the model.
- """
- serializer_class = self.serializer_class
-
- if serializer_class is None:
- class DefaultSerializer(self.model_serializer_class):
- class Meta:
- model = self.model
- serializer_class = DefaultSerializer
-
- return serializer_class
-
def get_serializer(self, instance=None, data=None,
files=None, many=False, partial=False):
"""
@@ -68,31 +78,7 @@ class GenericAPIView(views.APIView):
return serializer_class(instance, data=data, files=files,
many=many, partial=partial, context=context)
- def pre_save(self, obj):
- """
- Placeholder method for calling before saving an object.
- May be used eg. to set attributes on the object that are implicit
- in either the request, or the url.
- """
- pass
-
- def post_save(self, obj, created=False):
- """
- Placeholder method for calling after saving an object.
- """
- pass
-
-
-class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
- """
- Base class for generic views onto a queryset.
- """
-
- paginate_by = api_settings.PAGINATE_BY
- paginate_by_param = api_settings.PAGINATE_BY_PARAM
- pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
-
- def get_pagination_serializer(self, page=None):
+ def get_pagination_serializer(self, page):
"""
Return a serializer instance to use with paginated data.
"""
@@ -104,41 +90,214 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
- def get_paginate_by(self, queryset):
+ def paginate_queryset(self, queryset, page_size=None):
+ """
+ Paginate a queryset if required, either returning a page object,
+ or `None` if pagination is not configured for this view.
+ """
+ deprecated_style = False
+ if page_size is not None:
+ warnings.warn('The `page_size` parameter to `paginate_queryset()` '
+ 'is due to be deprecated. '
+ 'Note that the return style of this method is also '
+ 'changed, and will simply return a page object '
+ 'when called without a `page_size` argument.',
+ PendingDeprecationWarning, stacklevel=2)
+ deprecated_style = True
+ else:
+ # Determine the required page size.
+ # If pagination is not configured, simply return None.
+ page_size = self.get_paginate_by()
+ if not page_size:
+ return None
+
+ paginator = self.paginator_class(queryset, page_size,
+ allow_empty_first_page=self.allow_empty)
+ page_kwarg = self.kwargs.get(self.page_kwarg)
+ page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
+ page = page_kwarg or page_query_param or 1
+ try:
+ page_number = int(page)
+ except ValueError:
+ if page == 'last':
+ page_number = paginator.num_pages
+ else:
+ raise Http404(_("Page is not 'last', nor can it be converted to an int."))
+ try:
+ page = paginator.page(page_number)
+ except InvalidPage as e:
+ raise Http404(_('Invalid page (%(page_number)s): %(message)s') % {
+ 'page_number': page_number,
+ 'message': str(e)
+ })
+
+ if deprecated_style:
+ return (paginator, page, page.object_list, page.has_other_pages())
+ return page
+
+ def filter_queryset(self, queryset):
+ """
+ Given a queryset, filter it with whichever filter backend is in use.
+
+ You are unlikely to want to override this method, although you may need
+ to call it either from a list view, or from a custom `get_object`
+ method if you want to apply the configured filtering backend to the
+ default queryset.
+ """
+ if not self.filter_backend:
+ return queryset
+ backend = self.filter_backend()
+ return backend.filter_queryset(self.request, queryset, self)
+
+ ########################
+ ### The following methods provide default implementations
+ ### that you may want to override for more complex cases.
+
+ def get_paginate_by(self, queryset=None):
"""
Return the size of pages to use with pagination.
+
+ If `PAGINATE_BY_PARAM` is set it will attempt to get the page size
+ from a named query parameter in the url, eg. ?page_size=100
+
+ Otherwise defaults to using `self.paginate_by`.
"""
+ if queryset is not None:
+ warnings.warn('The `queryset` parameter to `get_paginate_by()` '
+ 'is due to be deprecated.',
+ PendingDeprecationWarning, stacklevel=2)
+
if self.paginate_by_param:
query_params = self.request.QUERY_PARAMS
try:
return int(query_params[self.paginate_by_param])
except (KeyError, ValueError):
pass
+
return self.paginate_by
+ def get_serializer_class(self):
+ """
+ Return the class to use for the serializer.
+ Defaults to using `self.serializer_class`.
-class SingleObjectAPIView(SingleObjectMixin, GenericAPIView):
- """
- Base class for generic views onto a model instance.
- """
+ You may want to override this if you need to provide different
+ serializations depending on the incoming request.
- pk_url_kwarg = 'pk' # Not provided in Django 1.3
- slug_url_kwarg = 'slug' # Not provided in Django 1.3
- slug_field = 'slug'
+ (Eg. admins get full serialization, others get basic serilization)
+ """
+ serializer_class = self.serializer_class
+ if serializer_class is not None:
+ return serializer_class
+
+ assert self.model is not None, \
+ "'%s' should either include a 'serializer_class' attribute, " \
+ "or use the 'model' attribute as a shortcut for " \
+ "automatically generating a serializer class." \
+ % self.__class__.__name__
+
+ class DefaultSerializer(self.model_serializer_class):
+ class Meta:
+ model = self.model
+ return DefaultSerializer
+
+ def get_queryset(self):
+ """
+ Get the list of items for this view.
+ This must be an iterable, and may be a queryset.
+ Defaults to using `self.queryset`.
+
+ You may want to override this if you need to provide different
+ querysets depending on the incoming request.
+
+ (Eg. return a list of items that is specific to the user)
+ """
+ if self.queryset is not None:
+ return self.queryset._clone()
+
+ if self.model is not None:
+ return self.model._default_manager.all()
+
+ raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'"
+ % self.__class__.__name__)
def get_object(self, queryset=None):
"""
- Override default to add support for object-level permissions.
+ Returns the object the view is displaying.
+
+ You may want to override this if you need to provide non-standard
+ queryset lookups. Eg if objects are referenced using multiple
+ keyword arguments in the url conf.
"""
- queryset = self.filter_queryset(self.get_queryset())
- obj = super(SingleObjectAPIView, self).get_object(queryset)
+ # Determine the base queryset to use.
+ if queryset is None:
+ queryset = self.filter_queryset(self.get_queryset())
+ else:
+ pass # Deprecation warning
+
+ # Perform the lookup filtering.
+ pk = self.kwargs.get(self.pk_url_kwarg, None)
+ slug = self.kwargs.get(self.slug_url_kwarg, None)
+ lookup = self.kwargs.get(self.lookup_field, None)
+
+ if lookup is not None:
+ filter_kwargs = {self.lookup_field: lookup}
+ elif pk is not None and self.lookup_field == 'pk':
+ warnings.warn(
+ 'The `pk_url_kwarg` attribute is due to be deprecated. '
+ 'Use the `lookup_field` attribute instead',
+ PendingDeprecationWarning
+ )
+ filter_kwargs = {'pk': pk}
+ elif slug is not None and self.lookup_field == 'pk':
+ warnings.warn(
+ 'The `slug_url_kwarg` attribute is due to be deprecated. '
+ 'Use the `lookup_field` attribute instead',
+ PendingDeprecationWarning
+ )
+ filter_kwargs = {self.slug_field: slug}
+ else:
+ raise ConfigurationError(
+ 'Expected view %s to be called with a URL keyword argument '
+ 'named "%s". Fix your URL conf, or set the `.lookup_field` '
+ 'attribute on the view correctly.' %
+ (self.__class__.__name__, self.lookup_field)
+ )
+
+ obj = get_object_or_404(queryset, **filter_kwargs)
+
+ # May raise a permission denied
self.check_object_permissions(self.request, obj)
+
return obj
+ ########################
+ ### The following are placeholder methods,
+ ### and are intended to be overridden.
+ ###
+ ### The are not called by GenericAPIView directly,
+ ### but are used by the mixin methods.
-### Concrete view classes that provide method handlers ###
-### by composing the mixin classes with a base view. ###
+ def pre_save(self, obj):
+ """
+ Placeholder method for calling before saving an object.
+
+ May be used to set attributes on the object that are implicit
+ in either the request, or the url.
+ """
+ pass
+ def post_save(self, obj, created=False):
+ """
+ Placeholder method for calling after saving an object.
+ """
+ pass
+
+
+##########################################################
+### Concrete view classes that provide method handlers ###
+### by composing the mixin classes with the base view. ###
+##########################################################
class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView):
@@ -151,7 +310,7 @@ class CreateAPIView(mixins.CreateModelMixin,
class ListAPIView(mixins.ListModelMixin,
- MultipleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for listing a queryset.
"""
@@ -160,7 +319,7 @@ class ListAPIView(mixins.ListModelMixin,
class RetrieveAPIView(mixins.RetrieveModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
@@ -169,7 +328,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,
class DestroyAPIView(mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for deleting a model instance.
@@ -179,7 +338,7 @@ class DestroyAPIView(mixins.DestroyModelMixin,
class UpdateAPIView(mixins.UpdateModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for updating a model instance.
@@ -188,13 +347,12 @@ class UpdateAPIView(mixins.UpdateModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
- MultipleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
@@ -207,7 +365,7 @@ class ListCreateAPIView(mixins.ListModelMixin,
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving, updating a model instance.
"""
@@ -218,13 +376,12 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
@@ -238,7 +395,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
- SingleObjectAPIView):
+ GenericAPIView):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
@@ -249,8 +406,31 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
- kwargs['partial'] = True
- return self.update(request, *args, **kwargs)
+ return self.partial_update(request, *args, **kwargs)
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
+
+
+##########################
+### Deprecated classes ###
+##########################
+
+class MultipleObjectAPIView(GenericAPIView):
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ 'Subclassing `MultipleObjectAPIView` is due to be deprecated. '
+ 'You should simply subclass `GenericAPIView` instead.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+ super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
+
+
+class SingleObjectAPIView(GenericAPIView):
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ 'Subclassing `SingleObjectAPIView` is due to be deprecated. '
+ 'You should simply subclass `GenericAPIView` instead.',
+ PendingDeprecationWarning, stacklevel=2
+ )
+ super(SingleObjectAPIView, self).__init__(*args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 3bd7d6df..ae703771 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -12,7 +12,7 @@ from rest_framework.response import Response
from rest_framework.request import clone_request
-def _get_validation_exclusions(obj, pk=None, slug_field=None):
+def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
@@ -23,14 +23,19 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None):
include = []
if pk:
+ # Pending deprecation
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
include.append(pk_field.name)
if slug_field:
+ # Pending deprecation
include.append(slug_field)
+ if lookup_field and lookup_field != 'pk':
+ include.append(lookup_field)
+
return [field.name for field in obj._meta.fields if field.name not in include]
@@ -67,23 +72,18 @@ class ListModelMixin(object):
empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
- queryset = self.get_queryset()
- self.object_list = self.filter_queryset(queryset)
+ self.object_list = self.filter_queryset(self.get_queryset())
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
- allow_empty = self.get_allow_empty()
- if not allow_empty and not self.object_list:
+ if not self.allow_empty and not self.object_list:
class_name = self.__class__.__name__
error_msg = self.empty_error % {'class_name': class_name}
raise Http404(error_msg)
- # Pagination size is set by the `.paginate_by` attribute,
- # which may be `None` to disable pagination.
- page_size = self.get_paginate_by(self.object_list)
- if page_size:
- packed = self.paginate_queryset(self.object_list, page_size)
- paginator, page, queryset, is_paginated = packed
+ # Switch between paginated or standard style responses
+ page = self.paginate_queryset(self.object_list)
+ if page is not None:
serializer = self.get_pagination_serializer(page)
else:
serializer = self.get_serializer(self.object_list, many=True)
@@ -135,14 +135,22 @@ class UpdateModelMixin(object):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ def partial_update(self, request, *args, **kwargs):
+ kwargs['partial'] = True
+ return self.update(request, *args, **kwargs)
+
def pre_save(self, obj):
"""
Set any attributes on the object that are implicit in the request.
"""
# pk and/or slug attributes are implicit in the URL.
+ lookup = self.kwargs.get(self.lookup_field, None)
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
- slug_field = slug and self.get_slug_field() or None
+ slug_field = slug and self.slug_field or None
+
+ if lookup:
+ setattr(obj, self.lookup_field, lookup)
if pk:
setattr(obj, 'pk', pk)
@@ -153,7 +161,7 @@ class UpdateModelMixin(object):
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
- exclude = _get_validation_exclusions(obj, pk, slug_field)
+ exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field)
obj.full_clean(exclude)
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index 0694d35f..4d205c0e 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -1,3 +1,7 @@
+"""
+Content negotiation deals with selecting an appropriate renderer given the
+incoming request. Typically this will be based on the request's Accept header.
+"""
from __future__ import unicode_literals
from django.http import Http404
from rest_framework import exceptions
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 03a7a30f..d51ea929 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -1,9 +1,11 @@
+"""
+Pagination serializers determine the structure of the output that should
+be used for paginated responses.
+"""
from __future__ import unicode_literals
from rest_framework import serializers
from rest_framework.templatetags.rest_framework import replace_query_param
-# TODO: Support URLconf kwarg-style paging
-
class NextPageField(serializers.Field):
"""
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index ae895f39..751f31a7 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -25,10 +25,12 @@ class BasePermission(object):
"""
Return `True` if permission is granted, `False` otherwise.
"""
- if len(inspect.getargspec(self.has_permission)[0]) == 4:
- warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. '
- 'Use `has_object_permission()` instead for object permissions.',
- PendingDeprecationWarning, stacklevel=2)
+ if len(inspect.getargspec(self.has_permission).args) == 4:
+ warnings.warn(
+ 'The `obj` argument in `has_permission` is deprecated. '
+ 'Use `has_object_permission()` instead for object permissions.',
+ DeprecationWarning, stacklevel=2
+ )
return self.has_permission(request, view, obj)
return True
@@ -87,8 +89,8 @@ class DjangoModelPermissions(BasePermission):
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the model.
- This permission will only be applied against view classes that
- provide a `.model` attribute, such as the generic class-based views.
+ This permission can only be applied against view classes that
+ provide a `.model` or `.queryset` attribute.
"""
# Map methods into required permission codes.
@@ -136,6 +138,14 @@ class DjangoModelPermissions(BasePermission):
return False
+class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
+ """
+ Similar to DjangoModelPermissions, except that anonymous users are
+ allowed read-only access.
+ """
+ authenticated_users_only = False
+
+
class TokenHasReadWriteScope(BasePermission):
"""
The request is authenticated as a user and the token used has the right scope
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 2a10e9af..fc5054b2 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,3 +1,9 @@
+"""
+Serializer fields that deal with relationships.
+
+These fields allow you to specify the style that should be used to represent
+model relationships, including hyperlinks, primary keys, or slugs.
+"""
from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
@@ -36,9 +42,9 @@ class RelatedField(WritableField):
# 'null' is to be deprecated in favor of 'required'
if 'null' in kwargs:
- warnings.warn('The `null` keyword argument is due to be deprecated. '
+ warnings.warn('The `null` keyword argument is deprecated. '
'Use the `required` keyword argument instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['required'] = not kwargs.pop('null')
self.queryset = kwargs.pop('queryset', None)
@@ -282,10 +288,8 @@ class HyperlinkedRelatedField(RelatedField):
"""
Represents a relationship using hyperlinking.
"""
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
read_only = False
+ lookup_field = 'pk'
default_error_messages = {
'no_match': _('Invalid hyperlink - No URL match'),
@@ -295,69 +299,138 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
}
+ # These are all pending deprecation
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+
def __init__(self, *args, **kwargs):
try:
self.view_name = kwargs.pop('view_name')
except KeyError:
raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+ self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
+ self.format = kwargs.pop('format', None)
+
+ # These are pending deprecation
+ if 'pk_url_kwarg' in kwargs:
+ msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_url_kwarg' in kwargs:
+ msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_field' in kwargs:
+ msg = 'slug_field is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
- self.format = kwargs.pop('format', None)
super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
- def get_slug_field(self):
+ def get_url(self, obj, view_name, request, format):
"""
- Get the name of a slug field to be used to look up by slug.
- """
- return self.slug_field
+ Given an object, return the URL that hyperlinks to the object.
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
-
- if request is None:
- warnings.warn("Using `HyperlinkedRelatedField` without including the "
- "request in the serializer context is due to be deprecated. "
- "Add `context={'request': request}` when instantiating the serializer.",
- PendingDeprecationWarning, stacklevel=4)
-
- pk = getattr(obj, 'pk', None)
- if pk is None:
- return
- kwargs = {self.pk_url_kwarg: pk}
+ May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
+ attributes are not configured to correctly match the URL conf.
+ """
+ lookup_field = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_field: lookup_field}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
+ if self.pk_url_kwarg != 'pk':
+ # Only try pk if it has been explicitly set.
+ # Otherwise, the default `lookup_field = 'pk'` has us covered.
+ pk = obj.pk
+ kwargs = {self.pk_url_kwarg: pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except NoReverseMatch:
+ pass
+
slug = getattr(obj, self.slug_field, None)
+ if slug is not None:
+ # Only try slug if it corresponds to an attribute on the object.
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ ret = reverse(view_name, kwargs=kwargs, request=request, format=format)
+ if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug':
+ # If the lookup succeeds using the default slug params,
+ # then `slug_field` is being used implicitly, and we
+ # we need to warn about the pending deprecation.
+ msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \
+ 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ return ret
+ except NoReverseMatch:
+ pass
+
+ raise NoReverseMatch()
+
+ def get_object(self, queryset, view_name, view_args, view_kwargs):
+ """
+ Return the object corresponding to a matched URL.
- if not slug:
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ Takes the matched URL conf arguments, and the queryset, and should
+ return an object instance, or raise an `ObjectDoesNotExist` exception.
+ """
+ lookup = view_kwargs.get(self.lookup_field, None)
+ pk = view_kwargs.get(self.pk_url_kwarg, None)
+ slug = view_kwargs.get(self.slug_url_kwarg, None)
+
+ if lookup is not None:
+ filter_kwargs = {self.lookup_field: lookup}
+ elif pk is not None:
+ filter_kwargs = {'pk': pk}
+ elif slug is not None:
+ filter_kwargs = {self.slug_field: slug}
+ else:
+ raise ObjectDoesNotExist()
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
+ return queryset.get(**filter_kwargs)
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ def to_native(self, obj):
+ view_name = self.view_name
+ request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
+
+ if request is None:
+ msg = (
+ "Using `HyperlinkedRelatedField` without including the request "
+ "in the serializer context is deprecated. "
+ "Add `context={'request': request}` when instantiating "
+ "the serializer."
+ )
+ warnings.warn(msg, DeprecationWarning, stacklevel=4)
+
+ # If the object has not yet been saved then we cannot hyperlink to it.
+ if getattr(obj, 'pk', None) is None:
+ return
+
+ # Return the hyperlink, or error if incorrectly configured.
try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ return self.get_url(obj, view_name, request, format)
except NoReverseMatch:
- pass
-
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ msg = (
+ 'Could not resolve URL for hyperlinked relationship using '
+ 'view name "%s". You may have failed to include the related '
+ 'model in your API, or incorrectly configured the '
+ '`lookup_field` attribute on this field.'
+ )
+ raise Exception(msg % view_name)
def from_native(self, value):
# Convert URL -> model instance pk
# TODO: Use values_list
- if self.queryset is None:
+ queryset = self.queryset
+ if queryset is None:
raise Exception('Writable related fields must include a `queryset` argument')
try:
@@ -381,29 +454,11 @@ class HyperlinkedRelatedField(RelatedField):
if match.view_name != self.view_name:
raise ValidationError(self.error_messages['incorrect_match'])
- pk = match.kwargs.get(self.pk_url_kwarg, None)
- slug = match.kwargs.get(self.slug_url_kwarg, None)
-
- # Try explicit primary key.
- if pk is not None:
- queryset = self.queryset.filter(pk=pk)
- # Next, try looking up by slug.
- elif slug is not None:
- slug_field = self.get_slug_field()
- queryset = self.queryset.filter(**{slug_field: slug})
- # If none of those are defined, it's probably a configuation error.
- else:
- raise ValidationError(self.error_messages['configuration_error'])
-
try:
- obj = queryset.get()
- except ObjectDoesNotExist:
+ return self.get_object(queryset, match.view_name,
+ match.args, match.kwargs)
+ except (ObjectDoesNotExist, TypeError, ValueError):
raise ValidationError(self.error_messages['does_not_exist'])
- except (TypeError, ValueError):
- msg = self.error_messages['incorrect_type']
- raise ValidationError(msg % type(value).__name__)
-
- return obj
class HyperlinkedIdentityField(Field):
@@ -437,9 +492,9 @@ class HyperlinkedIdentityField(Field):
if request is None:
warnings.warn("Using `HyperlinkedIdentityField` without including the "
- "request in the serializer context is due to be deprecated. "
+ "request in the serializer context is deprecated. "
"Add `context={'request': request}` when instantiating the serializer.",
- PendingDeprecationWarning, stacklevel=4)
+ DeprecationWarning, stacklevel=4)
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
@@ -482,35 +537,35 @@ class HyperlinkedIdentityField(Field):
class ManyRelatedField(RelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyRelatedField()` is deprecated. '
'Use `RelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyRelatedField, self).__init__(*args, **kwargs)
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. '
'Use `PrimaryKeyRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)
class ManySlugRelatedField(SlugRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManySlugRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManySlugRelatedField()` is deprecated. '
'Use `SlugRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManySlugRelatedField, self).__init__(*args, **kwargs)
class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):
def __init__(self, *args, **kwargs):
- warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. '
+ warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. '
'Use `HyperlinkedRelatedField(many=True)` instead.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
kwargs['many'] = True
super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 4c15e0db..c457ec73 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -24,6 +24,7 @@ from rest_framework.settings import api_settings
from rest_framework.request import clone_request
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
+from rest_framework.utils.formatting import get_view_name, get_view_description
from rest_framework import exceptions, parsers, status, VERSION
@@ -438,16 +439,13 @@ class BrowsableAPIRenderer(BaseRenderer):
return GenericContentForm()
def get_name(self, view):
- try:
- return view.get_name()
- except AttributeError:
- return smart_text(view.__class__.__name__)
+ return get_view_name(view.__class__, getattr(view, 'suffix', None))
def get_description(self, view):
- try:
- return view.get_description(html=True)
- except AttributeError:
- return smart_text(view.__doc__ or '')
+ return get_view_description(view.__class__, html=True)
+
+ def get_breadcrumbs(self, request):
+ return get_breadcrumbs(request.path)
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -480,7 +478,7 @@ class BrowsableAPIRenderer(BaseRenderer):
name = self.get_name(view)
description = self.get_description(view)
- breadcrumb_list = get_breadcrumbs(request.path)
+ breadcrumb_list = self.get_breadcrumbs(request)
template = loader.get_template(self.template)
context = RequestContext(request, {
diff --git a/rest_framework/request.py b/rest_framework/request.py
index ffbbab33..a434659c 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -1,11 +1,10 @@
"""
-The :mod:`request` module provides a :class:`Request` class used to wrap the standard `request`
-object received in all the views.
+The Request class is used as a wrapper around the standard request object.
The wrapped request then offers a richer API, in particular :
- content automatically parsed according to `Content-Type` header,
- and available as :meth:`.DATA<Request.DATA>`
+ and available as `request.DATA`
- full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content
"""
diff --git a/rest_framework/response.py b/rest_framework/response.py
index 5e1bf46e..26e4ab37 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -1,3 +1,9 @@
+"""
+The Response class in REST framework is similiar to HTTPResponse, except that
+it is initialized with unrendered data, instead of a pre-rendered string.
+
+The appropriate renderer is called during Django's template response rendering.
+"""
from __future__ import unicode_literals
from django.core.handlers.wsgi import STATUS_CODE_TEXT
from django.template.response import SimpleTemplateResponse
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
new file mode 100644
index 00000000..0707635a
--- /dev/null
+++ b/rest_framework/routers.py
@@ -0,0 +1,246 @@
+"""
+Routers provide a convenient and consistent way of automatically
+determining the URL conf for your API.
+
+They are used by simply instantiating a Router class, and then registering
+all the required ViewSets with that router.
+
+For example, you might have a `urls.py` that looks something like this:
+
+ router = routers.DefaultRouter()
+ router.register('users', UserViewSet, 'user')
+ router.register('accounts', AccountViewSet, 'account')
+
+ urlpatterns = router.urls
+"""
+from __future__ import unicode_literals
+
+from collections import namedtuple
+from django.conf.urls import url, patterns
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.reverse import reverse
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs'])
+
+
+def replace_methodname(format_string, methodname):
+ """
+ Partially format a format_string, swapping out any
+ '{methodname}' or '{methodnamehyphen}' components.
+ """
+ methodnamehyphen = methodname.replace('_', '-')
+ ret = format_string
+ ret = ret.replace('{methodname}', methodname)
+ ret = ret.replace('{methodnamehyphen}', methodnamehyphen)
+ return ret
+
+
+class BaseRouter(object):
+ def __init__(self):
+ self.registry = []
+
+ def register(self, prefix, viewset, base_name=None):
+ if base_name is None:
+ base_name = self.get_default_base_name(viewset)
+ self.registry.append((prefix, viewset, base_name))
+
+ def get_default_base_name(self, viewset):
+ """
+ If `base_name` is not specified, attempt to automatically determine
+ it from the viewset.
+ """
+ raise NotImplemented('get_default_base_name must be overridden')
+
+ def get_urls(self):
+ """
+ Return a list of URL patterns, given the registered viewsets.
+ """
+ raise NotImplemented('get_urls must be overridden')
+
+ @property
+ def urls(self):
+ if not hasattr(self, '_urls'):
+ self._urls = patterns('', *self.get_urls())
+ return self._urls
+
+
+class SimpleRouter(BaseRouter):
+ routes = [
+ # List route.
+ Route(
+ url=r'^{prefix}/$',
+ mapping={
+ 'get': 'list',
+ 'post': 'create'
+ },
+ name='{basename}-list',
+ initkwargs={'suffix': 'List'}
+ ),
+ # Detail route.
+ Route(
+ url=r'^{prefix}/{lookup}/$',
+ mapping={
+ 'get': 'retrieve',
+ 'put': 'update',
+ 'patch': 'partial_update',
+ 'delete': 'destroy'
+ },
+ name='{basename}-detail',
+ initkwargs={'suffix': 'Instance'}
+ ),
+ # Dynamically generated routes.
+ # Generated using @action or @link decorators on methods of the viewset.
+ Route(
+ url=r'^{prefix}/{lookup}/{methodname}/$',
+ mapping={
+ '{httpmethod}': '{methodname}',
+ },
+ name='{basename}-{methodnamehyphen}',
+ initkwargs={}
+ ),
+ ]
+
+ def get_default_base_name(self, viewset):
+ """
+ If `base_name` is not specified, attempt to automatically determine
+ it from the viewset.
+ """
+ model_cls = getattr(viewset, 'model', None)
+ queryset = getattr(viewset, 'queryset', None)
+ if model_cls is None and queryset is not None:
+ model_cls = queryset.model
+
+ assert model_cls, '`name` not argument not specified, and could ' \
+ 'not automatically determine the name from the viewset, as ' \
+ 'it does not have a `.model` or `.queryset` attribute.'
+
+ return model_cls._meta.object_name.lower()
+
+ def get_routes(self, viewset):
+ """
+ Augment `self.routes` with any dynamically generated routes.
+
+ Returns a list of the Route namedtuple.
+ """
+
+ # Determine any `@action` or `@link` decorated methods on the viewset
+ dynamic_routes = {}
+ for methodname in dir(viewset):
+ attr = getattr(viewset, methodname)
+ httpmethod = getattr(attr, 'bind_to_method', None)
+ if httpmethod:
+ dynamic_routes[httpmethod] = methodname
+
+ ret = []
+ for route in self.routes:
+ if route.mapping == {'{httpmethod}': '{methodname}'}:
+ # Dynamic routes (@link or @action decorator)
+ for httpmethod, methodname in dynamic_routes.items():
+ initkwargs = route.initkwargs.copy()
+ initkwargs.update(getattr(viewset, methodname).kwargs)
+ ret.append(Route(
+ url=replace_methodname(route.url, methodname),
+ mapping={httpmethod: methodname},
+ name=replace_methodname(route.name, methodname),
+ initkwargs=initkwargs,
+ ))
+ else:
+ # Standard route
+ ret.append(route)
+
+ return ret
+
+ def get_method_map(self, viewset, method_map):
+ """
+ Given a viewset, and a mapping of http methods to actions,
+ return a new mapping which only includes any mappings that
+ are actually implemented by the viewset.
+ """
+ bound_methods = {}
+ for method, action in method_map.items():
+ if hasattr(viewset, action):
+ bound_methods[method] = action
+ return bound_methods
+
+ def get_lookup_regex(self, viewset):
+ """
+ Given a viewset, return the portion of URL regex that is used
+ to match against a single instance.
+ """
+ base_regex = '(?P<{lookup_field}>[^/]+)'
+ lookup_field = getattr(viewset, 'lookup_field', 'pk')
+ return base_regex.format(lookup_field=lookup_field)
+
+ def get_urls(self):
+ """
+ Use the registered viewsets to generate a list of URL patterns.
+ """
+ ret = []
+
+ for prefix, viewset, basename in self.registry:
+ lookup = self.get_lookup_regex(viewset)
+ routes = self.get_routes(viewset)
+
+ for route in routes:
+
+ # Only actions which actually exist on the viewset will be bound
+ mapping = self.get_method_map(viewset, route.mapping)
+ if not mapping:
+ continue
+
+ # Build the url pattern
+ regex = route.url.format(prefix=prefix, lookup=lookup)
+ view = viewset.as_view(mapping, **route.initkwargs)
+ name = route.name.format(basename=basename)
+ ret.append(url(regex, view, name=name))
+
+ return ret
+
+
+class DefaultRouter(SimpleRouter):
+ """
+ The default router extends the SimpleRouter, but also adds in a default
+ API root view, and adds format suffix patterns to the URLs.
+ """
+ include_root_view = True
+ include_format_suffixes = True
+
+ def get_api_root_view(self):
+ """
+ Return a view to use as the API root.
+ """
+ api_root_dict = {}
+ list_name = self.routes[0].name
+ for prefix, viewset, basename in self.registry:
+ api_root_dict[prefix] = list_name.format(basename=basename)
+
+ @api_view(('GET',))
+ def api_root(request, format=None):
+ ret = {}
+ for key, url_name in api_root_dict.items():
+ ret[key] = reverse(url_name, request=request, format=format)
+ return Response(ret)
+
+ return api_root
+
+ def get_urls(self):
+ """
+ Generate the list of URL patterns, including a default root view
+ for the API, and appending `.json` style format suffixes.
+ """
+ urls = []
+
+ if self.include_root_view:
+ root_url = url(r'^$', self.get_api_root_view(), name='api-root')
+ urls.append(root_url)
+
+ default_urls = super(DefaultRouter, self).get_urls()
+ urls.extend(default_urls)
+
+ if self.include_format_suffixes:
+ urls = format_suffix_patterns(urls)
+
+ return urls
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index add46566..ea5175e2 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -1,3 +1,15 @@
+"""
+Serializers and ModelSerializers are similar to Forms and ModelForms.
+Unlike forms, they are not constrained to dealing with HTML output, and
+form encoded input.
+
+Serialization in REST framework is a two-phase process:
+
+1. Serializers marshal between complex types like model instances, and
+python primatives.
+2. The process of marshalling between python primatives and request and
+response content is handled by parsers and renderers.
+"""
from __future__ import unicode_literals
import copy
import datetime
@@ -412,9 +424,9 @@ class BaseSerializer(WritableField):
else:
many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
if many:
- warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ warnings.warn('Implict list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
- PendingDeprecationWarning, stacklevel=3)
+ DeprecationWarning, stacklevel=3)
if many:
ret = []
@@ -474,9 +486,9 @@ class BaseSerializer(WritableField):
else:
many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
if many:
- warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ warnings.warn('Implict list/queryset serialization is deprecated. '
'Use the `many=True` flag when instantiating the serializer.',
- PendingDeprecationWarning, stacklevel=2)
+ DeprecationWarning, stacklevel=2)
if many:
self._data = [self.to_native(item) for item in obj]
@@ -536,6 +548,7 @@ class ModelSerializer(Serializer):
models.DateTimeField: DateTimeField,
models.DateField: DateField,
models.TimeField: TimeField,
+ models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.CharField: CharField,
models.URLField: URLField,
@@ -556,36 +569,85 @@ class ModelSerializer(Serializer):
assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta
- pk_field = opts.pk
+ ret = SortedDict()
+ nested = bool(self.opts.depth)
- # If model is a child via multitable inheritance, use parent's pk
+ # Deal with adding the primary key field
+ pk_field = opts.pk
while pk_field.rel and pk_field.rel.parent_link:
+ # If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
- fields = [pk_field]
- fields += [field for field in opts.fields if field.serialize]
- fields += [field for field in opts.many_to_many if field.serialize]
+ field = self.get_pk_field(pk_field)
+ if field:
+ ret[pk_field.name] = field
- ret = SortedDict()
- nested = bool(self.opts.depth)
- is_pk = True # First field in the list is the pk
-
- for model_field in fields:
- if is_pk:
- field = self.get_pk_field(model_field)
- is_pk = False
- elif model_field.rel and nested:
- field = self.get_nested_field(model_field)
- elif model_field.rel:
+ # Deal with forward relationships
+ forward_rels = [field for field in opts.fields if field.serialize]
+ forward_rels += [field for field in opts.many_to_many if field.serialize]
+
+ for model_field in forward_rels:
+ if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
- field = self.get_related_field(model_field, to_many=to_many)
+ related_model = model_field.rel.to
+
+ if model_field.rel and nested:
+ if len(inspect.getargspec(self.get_nested_field).args) == 2:
+ warnings.warn(
+ 'The `get_nested_field(model_field)` call signature '
+ 'is due to be deprecated. '
+ 'Use `get_nested_field(model_field, related_model, '
+ 'to_many) instead',
+ PendingDeprecationWarning
+ )
+ field = self.get_nested_field(model_field)
+ else:
+ field = self.get_nested_field(model_field, related_model, to_many)
+ elif model_field.rel:
+ if len(inspect.getargspec(self.get_nested_field).args) == 3:
+ warnings.warn(
+ 'The `get_related_field(model_field, to_many)` call '
+ 'signature is due to be deprecated. '
+ 'Use `get_related_field(model_field, related_model, '
+ 'to_many) instead',
+ PendingDeprecationWarning
+ )
+ field = self.get_related_field(model_field, to_many=to_many)
+ else:
+ field = self.get_related_field(model_field, related_model, to_many)
else:
field = self.get_field(model_field)
if field:
ret[model_field.name] = field
+ # Deal with reverse relationships
+ if not self.opts.fields:
+ reverse_rels = []
+ else:
+ # Reverse relationships are only included if they are explicitly
+ # present in the `fields` option on the serializer
+ reverse_rels = opts.get_all_related_objects()
+ reverse_rels += opts.get_all_related_many_to_many_objects()
+
+ for relation in reverse_rels:
+ accessor_name = relation.get_accessor_name()
+ if not self.opts.fields or accessor_name not in self.opts.fields:
+ continue
+ related_model = relation.model
+ to_many = relation.field.rel.multiple
+
+ if nested:
+ field = self.get_nested_field(None, related_model, to_many)
+ else:
+ field = self.get_related_field(None, related_model, to_many)
+
+ if field:
+ ret[accessor_name] = field
+
+ # Add the `read_only` flag to any fields that have bee specified
+ # in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name in ret, \
"read_only_fields on '%s' included invalid item '%s'" % \
@@ -600,29 +662,36 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
- def get_nested_field(self, model_field):
+ def get_nested_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a nested relational field.
+
+ Note that model_field will be `None` for reverse relationships.
"""
class NestedModelSerializer(ModelSerializer):
class Meta:
- model = model_field.rel.to
+ model = related_model
depth = self.opts.depth - 1
- return NestedModelSerializer()
+ return NestedModelSerializer(many=to_many)
- def get_related_field(self, model_field, to_many=False):
+ def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
+
+ Note that model_field will be `None` for reverse relationships.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
+
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': model_field.rel.to._default_manager,
+ 'queryset': related_model._default_manager,
'many': to_many
}
+ if model_field:
+ kwargs['required'] = not(model_field.null or model_field.blank)
+
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
@@ -758,6 +827,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
+ self.lookup_field = getattr(meta, 'slug_field', None)
class HyperlinkedModelSerializer(ModelSerializer):
@@ -767,6 +837,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
+ _hyperlink_field_class = HyperlinkedRelatedField
url = HyperlinkedIdentityField()
@@ -787,22 +858,28 @@ class HyperlinkedModelSerializer(ModelSerializer):
return self._default_view_name % format_kwargs
def get_pk_field(self, model_field):
- return None
+ if self.opts.fields and model_field.name in self.opts.fields:
+ return self.get_field(model_field)
- def get_related_field(self, model_field, to_many):
+ def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- rel = model_field.rel.to
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': rel._default_manager,
- 'view_name': self._get_default_view_name(rel),
+ 'queryset': related_model._default_manager,
+ 'view_name': self._get_default_view_name(related_model),
'many': to_many
}
- return HyperlinkedRelatedField(**kwargs)
+
+ if model_field:
+ kwargs['required'] = not(model_field.null or model_field.blank)
+
+ if self.opts.lookup_field:
+ kwargs['lookup_field'] = self.opts.lookup_field
+
+ return self._hyperlink_field_class(**kwargs)
def get_identity(self, data):
"""
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
index 5b3315bc..52c1a34c 100644
--- a/rest_framework/tests/description.py
+++ b/rest_framework/tests/description.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.views import APIView
from rest_framework.compat import apply_markdown
+from rest_framework.utils.formatting import get_view_name, get_view_description
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
@@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
class TestViewNamesAndDescriptions(TestCase):
- def test_resource_name_uses_classname_by_default(self):
- """Ensure Resource names are based on the classname by default."""
+ def test_view_name_uses_class_name(self):
+ """
+ Ensure view names are based on the class name.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_name(), 'Mock')
+ self.assertEqual(get_view_name(MockView), 'Mock')
- def test_resource_name_can_be_set_explicitly(self):
- """Ensure Resource names can be set using the 'get_name' method."""
- example = 'Some Other Name'
- class MockView(APIView):
- def get_name(self):
- return example
- self.assertEqual(MockView().get_name(), example)
-
- def test_resource_description_uses_docstring_by_default(self):
- """Ensure Resource names are based on the docstring by default."""
+ def test_view_description_uses_docstring(self):
+ """Ensure view descriptions are based on the docstring."""
class MockView(APIView):
"""an example docstring
====================
@@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase):
# hash style header #"""
- self.assertEqual(MockView().get_description(), DESCRIPTION)
-
- def test_resource_description_can_be_set_explicitly(self):
- """Ensure Resource descriptions can be set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- """docstring"""
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), DESCRIPTION)
- def test_resource_description_supports_unicode(self):
+ def test_view_description_supports_unicode(self):
+ """
+ Unicode in docstrings should be respected.
+ """
class MockView(APIView):
"""Проверка"""
pass
- self.assertEqual(MockView().get_description(), "Проверка")
-
-
- def test_resource_description_does_not_require_docstring(self):
- """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
- example = 'Some other description'
-
- class MockView(APIView):
- def get_description(self):
- return example
- self.assertEqual(MockView().get_description(), example)
+ self.assertEqual(get_view_description(MockView), "Проверка")
- def test_resource_description_can_be_empty(self):
- """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
+ def test_view_description_can_be_empty(self):
+ """
+ Ensure that if a view has no docstring,
+ then it's description is the empty string.
+ """
class MockView(APIView):
pass
- self.assertEqual(MockView().get_description(), '')
+ self.assertEqual(get_view_description(MockView), '')
def test_markdown(self):
- """Ensure markdown to HTML works as expected"""
+ """
+ Ensure markdown to HTML works as expected.
+ """
if apply_markdown:
gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 19c663d8..3cdfa0f6 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -3,12 +3,14 @@ General serializer field tests.
"""
from __future__ import unicode_literals
import datetime
+from decimal import Decimal
from django.db import models
from django.test import TestCase
from django.core import validators
from rest_framework import serializers
+from rest_framework.serializers import Serializer
class TimestampedModel(models.Model):
@@ -481,3 +483,166 @@ class TimeFieldTest(TestCase):
self.assertEqual('04 - 00 [000000]', result_1)
self.assertEqual('04 - 59 [000000]', result_2)
self.assertEqual('04 - 59 [000200]', result_3)
+
+
+class DecimalFieldTest(TestCase):
+ """
+ Tests for the DecimalField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts string values
+ """
+ f = serializers.DecimalField()
+ result_1 = f.from_native('9000')
+ result_2 = f.from_native('1.00000001')
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_from_native_invalid_string(self):
+ """
+ Make sure from_native() raises ValidationError on passing invalid string
+ """
+ f = serializers.DecimalField()
+
+ try:
+ f.from_native('123.45.6')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Enter a number."])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_integer(self):
+ """
+ Make sure from_native() accepts integer values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(9000)
+
+ self.assertEqual(Decimal('9000'), result)
+
+ def test_from_native_float(self):
+ """
+ Make sure from_native() accepts float values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(1.00000001)
+
+ self.assertEqual(Decimal('1.00000001'), result)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns Decimal as string.
+ """
+ f = serializers.DecimalField()
+
+ result_1 = f.to_native(Decimal('9000'))
+ result_2 = f.to_native(Decimal('1.00000001'))
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+ def test_valid_serialization(self):
+ """
+ Make sure the serializer works correctly
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=9010,
+ min_value=9000,
+ max_digits=6,
+ decimal_places=2)
+
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
+
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
+
+ def test_raise_max_value(self):
+ """
+ Make sure max_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '123'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
+
+ def test_raise_min_value(self):
+ """
+ Make sure min_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(min_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '99'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
+
+ def test_raise_max_digits(self):
+ """
+ Make sure max_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=5)
+
+ s = DecimalSerializer(data={'decimal_field': '123.456'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
+
+ def test_raise_max_decimal_places(self):
+ """
+ Make sure max_decimal_places violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '123.4567'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
+
+ def test_raise_max_whole_digits(self):
+ """
+ Make sure max_whole_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '12345.6'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index 1a71558c..1e53a5cd 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -61,7 +61,7 @@ if django_filters:
class CommonFilteringTestCase(TestCase):
def _serialize_object(self, obj):
return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
-
+
def setUp(self):
"""
Create 10 FilterableItem instances.
@@ -190,7 +190,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
Integration tests for filtered detail views.
"""
urls = 'rest_framework.tests.filterset'
-
+
def _get_url(self, item):
return reverse('detail-view', kwargs=dict(pk=item.pk))
@@ -221,7 +221,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, low_item_data)
-
+
# Tests that multiple filters works.
search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2)
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index b5702a48..b1eed9a7 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -26,42 +26,44 @@ urlpatterns = patterns('',
)
+# ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
-
class Meta:
model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
-
class Meta:
model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('url', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
- nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
-
class Meta:
model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index a125ba65..f6d006b3 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
- model = ForeignKeySource
-
-
-class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+ depth = 1
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = FlatForeignKeySourceSerializer(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+ depth = 1
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
model = NullableForeignKeySource
-
-
-class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableOneToOneSource
+ fields = ('id', 'name', 'target')
+ depth = 1
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = NullableOneToOneSourceSerializer()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+ depth = 1
class ReverseForeignKeyTests(TestCase):
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index f08e1808..5ce8b567 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore
from rest_framework.compat import six
+# ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = serializers.PrimaryKeyRelatedField()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index bd874253..84e1ee4e 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -357,7 +357,6 @@ class CustomValidationTests(TestCase):
def validate_email(self, attrs, source):
value = attrs[source]
-
return attrs
def validate_content(self, attrs, source):
@@ -738,6 +737,43 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post")
@@ -1073,7 +1109,7 @@ class DeserializeListTestCase(TestCase):
def test_no_errors(self):
data = [self.data.copy() for x in range(0, 3)]
- serializer = CommentSerializer(data=data)
+ serializer = CommentSerializer(data=data, many=True)
self.assertTrue(serializer.is_valid())
self.assertTrue(isinstance(serializer.object, list))
self.assertTrue(
@@ -1085,7 +1121,7 @@ class DeserializeListTestCase(TestCase):
invalid_item['email'] = ''
data = [self.data.copy(), invalid_item, self.data.copy()]
- serializer = CommentSerializer(data=data)
+ serializer = CommentSerializer(data=data, many=True)
self.assertFalse(serializer.is_valid())
expected = [{}, {'email': ['This field is required.']}, {}]
self.assertEqual(serializer.errors, expected)
diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py
index 6a29c652..71d0e24b 100644
--- a/rest_framework/tests/serializer_nested.py
+++ b/rest_framework/tests/serializer_nested.py
@@ -109,7 +109,7 @@ class WritableNestedSerializerBasicTests(TestCase):
}
]
- serializer = self.AlbumSerializer(data=data)
+ serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors)
@@ -241,6 +241,6 @@ class WritableNestedSerializerObjectTests(TestCase):
)
]
- serializer = self.AlbumSerializer(data=data)
+ serializer = self.AlbumSerializer(data=data, many=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.object, expected_object)
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 810cad63..93ea9816 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -1,3 +1,6 @@
+"""
+Provides various throttling policies.
+"""
from __future__ import unicode_literals
from django.core.cache import cache
from rest_framework import exceptions
@@ -28,9 +31,8 @@ class SimpleRateThrottle(BaseThrottle):
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.
- The rate (requests / seconds) is set by a :attr:`throttle` attribute
- on the :class:`.View` class. The attribute is a string of the form 'number of
- requests/period'.
+ The rate (requests / seconds) is set by a `throttle` attribute on the View
+ class. The attribute is a string of the form 'number_of_requests/period'.
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index af21ac79..28801d09 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -1,26 +1,36 @@
from __future__ import unicode_literals
from django.core.urlresolvers import resolve, get_script_prefix
+from rest_framework.utils.formatting import get_view_name
def get_breadcrumbs(url):
- """Given a url returns a list of breadcrumbs, which are each a tuple of (name, url)."""
+ """
+ Given a url returns a list of breadcrumbs, which are each a
+ tuple of (name, url).
+ """
from rest_framework.views import APIView
def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):
- """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url."""
+ """
+ Add tuples of (name, url) to the breadcrumbs list,
+ progressively chomping off parts of the url.
+ """
try:
(view, unused_args, unused_kwargs) = resolve(url)
except Exception:
pass
else:
- # Check if this is a REST framework view, and if so add it to the breadcrumbs
- if isinstance(getattr(view, 'cls_instance', None), APIView):
+ # Check if this is a REST framework view,
+ # and if so add it to the breadcrumbs
+ if issubclass(getattr(view, 'cls', None), APIView):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
if not seen or seen[-1] != view:
- breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url))
+ suffix = getattr(view, 'suffix', None)
+ name = get_view_name(view.cls, suffix)
+ breadcrumbs_list.insert(0, (name, prefix + url))
seen.append(view)
if url == '':
@@ -28,11 +38,15 @@ def get_breadcrumbs(url):
return breadcrumbs_list
elif url.endswith('/'):
- # Drop trailing slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen)
-
- # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs
- return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen)
+ # Drop trailing slash off the end and continue to try to
+ # resolve more breadcrumbs
+ url = url.rstrip('/')
+ return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
+
+ # Drop trailing non-slash off the end and continue to try to
+ # resolve more breadcrumbs
+ url = url[:url.rfind('/') + 1]
+ return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/')
url = url[len(prefix):]
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
new file mode 100644
index 00000000..ebadb3a6
--- /dev/null
+++ b/rest_framework/utils/formatting.py
@@ -0,0 +1,80 @@
+"""
+Utility functions to return a formatted name and description for a given view.
+"""
+from __future__ import unicode_literals
+
+from django.utils.html import escape
+from django.utils.safestring import mark_safe
+from rest_framework.compat import apply_markdown
+import re
+
+
+def _remove_trailing_string(content, trailing):
+ """
+ Strip trailing component `trailing` from `content` if it exists.
+ Used when generating names from view classes.
+ """
+ if content.endswith(trailing) and content != trailing:
+ return content[:-len(trailing)]
+ return content
+
+
+def _remove_leading_indent(content):
+ """
+ Remove leading indent from a block of text.
+ Used when generating descriptions from docstrings.
+ """
+ whitespace_counts = [len(line) - len(line.lstrip(' '))
+ for line in content.splitlines()[1:] if line.lstrip()]
+
+ # unindent the content if needed
+ if whitespace_counts:
+ whitespace_pattern = '^' + (' ' * min(whitespace_counts))
+ content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
+ content = content.strip('\n')
+ return content
+
+
+def _camelcase_to_spaces(content):
+ """
+ Translate 'CamelCaseNames' to 'Camel Case Names'.
+ Used when generating names from view classes.
+ """
+ camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
+ content = re.sub(camelcase_boundry, ' \\1', content).strip()
+ return ' '.join(content.split('_')).title()
+
+
+def get_view_name(cls, suffix=None):
+ """
+ Return a formatted name for an `APIView` class or `@api_view` function.
+ """
+ name = cls.__name__
+ name = _remove_trailing_string(name, 'View')
+ name = _remove_trailing_string(name, 'ViewSet')
+ name = _camelcase_to_spaces(name)
+ if suffix:
+ name += ' ' + suffix
+ return name
+
+
+def get_view_description(cls, html=False):
+ """
+ Return a description for an `APIView` class or `@api_view` function.
+ """
+ description = cls.__doc__ or ''
+ description = _remove_leading_indent(description)
+ if html:
+ return markup_description(description)
+ return description
+
+
+def markup_description(description):
+ """
+ Apply HTML markup to the given description.
+ """
+ if apply_markdown:
+ description = apply_markdown(description)
+ else:
+ description = escape(description).replace('\n', '<br />')
+ return mark_safe(description)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 7c97607b..555fa2f4 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -1,54 +1,16 @@
"""
-Provides an APIView class that is used as the base of all class-based views.
+Provides an APIView class that is the base of all views in REST framework.
"""
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
from django.http import Http404, HttpResponse
-from django.utils.html import escape
-from django.utils.safestring import mark_safe
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
-from rest_framework.compat import View, apply_markdown
+from rest_framework.compat import View
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.settings import api_settings
-import re
-
-
-def _remove_trailing_string(content, trailing):
- """
- Strip trailing component `trailing` from `content` if it exists.
- Used when generating names from view classes.
- """
- if content.endswith(trailing) and content != trailing:
- return content[:-len(trailing)]
- return content
-
-
-def _remove_leading_indent(content):
- """
- Remove leading indent from a block of text.
- Used when generating descriptions from docstrings.
- """
- whitespace_counts = [len(line) - len(line.lstrip(' '))
- for line in content.splitlines()[1:] if line.lstrip()]
-
- # unindent the content if needed
- if whitespace_counts:
- whitespace_pattern = '^' + (' ' * min(whitespace_counts))
- content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
- content = content.strip('\n')
- return content
-
-
-def _camelcase_to_spaces(content):
- """
- Translate 'CamelCaseNames' to 'Camel Case Names'.
- Used when generating names from view classes.
- """
- camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
- content = re.sub(camelcase_boundry, ' \\1', content).strip()
- return ' '.join(content.split('_')).title()
+from rest_framework.utils.formatting import get_view_name, get_view_description
class APIView(View):
@@ -64,22 +26,21 @@ class APIView(View):
@classmethod
def as_view(cls, **initkwargs):
"""
- Override the default :meth:`as_view` to store an instance of the view
- as an attribute on the callable function. This allows us to discover
- information about the view when we do URL reverse lookups.
+ Store the original class on the view function.
+
+ This allows us to discover information about the view when we do URL
+ reverse lookups. Used for breadcrumb generation.
"""
- # TODO: deprecate?
view = super(APIView, cls).as_view(**initkwargs)
- view.cls_instance = cls(**initkwargs)
+ view.cls = cls
return view
@property
def allowed_methods(self):
"""
- Return the list of allowed HTTP methods, uppercased.
+ Wrap Django's private `_allowed_methods` interface in a public property.
"""
- return [method.upper() for method in self.http_method_names
- if hasattr(self, method)]
+ return self._allowed_methods()
@property
def default_response_headers(self):
@@ -90,43 +51,10 @@ class APIView(View):
'Vary': 'Accept'
}
- def get_name(self):
- """
- Return the resource or view class name for use as this view's name.
- Override to customize.
- """
- # TODO: deprecate?
- name = self.__class__.__name__
- name = _remove_trailing_string(name, 'View')
- return _camelcase_to_spaces(name)
-
- def get_description(self, html=False):
- """
- Return the resource or view docstring for use as this view's description.
- Override to customize.
- """
- # TODO: deprecate?
- description = self.__doc__ or ''
- description = _remove_leading_indent(description)
- if html:
- return self.markup_description(description)
- return description
-
- def markup_description(self, description):
- """
- Apply HTML markup to the description of this view.
- """
- # TODO: deprecate?
- if apply_markdown:
- description = apply_markdown(description)
- else:
- description = escape(description).replace('\n', '<br />')
- return mark_safe(description)
-
def metadata(self, request):
return {
- 'name': self.get_name(),
- 'description': self.get_description(),
+ 'name': get_view_name(self.__class__),
+ 'description': get_view_description(self.__class__),
'renders': [renderer.media_type for renderer in self.renderer_classes],
'parses': [parser.media_type for parser in self.parser_classes],
}
@@ -140,7 +68,8 @@ class APIView(View):
def http_method_not_allowed(self, request, *args, **kwargs):
"""
- Called if `request.method` does not correspond to a handler method.
+ If `request.method` does not correspond to a handler method,
+ determine what kind of exception to raise.
"""
raise exceptions.MethodNotAllowed(request.method)
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
new file mode 100644
index 00000000..0eb3e86d
--- /dev/null
+++ b/rest_framework/viewsets.py
@@ -0,0 +1,132 @@
+"""
+ViewSets are essentially just a type of class based view, that doesn't provide
+any method handlers, such as `get()`, `post()`, etc... but instead has actions,
+such as `list()`, `retrieve()`, `create()`, etc...
+
+Actions are only bound to methods at the point of instantiating the views.
+
+ user_list = UserViewSet.as_view({'get': 'list'})
+ user_detail = UserViewSet.as_view({'get': 'retrieve'})
+
+Typically, rather than instantiate views from viewsets directly, you'll
+regsiter the viewset with a router and let the URL conf be determined
+automatically.
+
+ router = DefaultRouter()
+ router.register(r'users', UserViewSet, 'user')
+ urlpatterns = router.urls
+"""
+from __future__ import unicode_literals
+
+from functools import update_wrapper
+from django.utils.decorators import classonlymethod
+from rest_framework import views, generics, mixins
+
+
+class ViewSetMixin(object):
+ """
+ This is the magic.
+
+ Overrides `.as_view()` so that it takes an `actions` keyword that performs
+ the binding of HTTP methods to actions on the Resource.
+
+ For example, to create a concrete view binding the 'GET' and 'POST' methods
+ to the 'list' and 'create' actions...
+
+ view = MyViewSet.as_view({'get': 'list', 'post': 'create'})
+ """
+
+ @classonlymethod
+ def as_view(cls, actions=None, **initkwargs):
+ """
+ Because of the way class based views create a closure around the
+ instantiated view, we need to totally reimplement `.as_view`,
+ and slightly modify the view function that is created and returned.
+ """
+ # The suffix initkwarg is reserved for identifing the viewset type
+ # eg. 'List' or 'Instance'.
+ cls.suffix = None
+
+ # sanitize keyword arguments
+ for key in initkwargs:
+ if key in cls.http_method_names:
+ raise TypeError("You tried to pass in the %s method name as a "
+ "keyword argument to %s(). Don't do that."
+ % (key, cls.__name__))
+ if not hasattr(cls, key):
+ raise TypeError("%s() received an invalid keyword %r" % (
+ cls.__name__, key))
+
+ def view(request, *args, **kwargs):
+ self = cls(**initkwargs)
+ # We also store the mapping of request methods to actions,
+ # so that we can later set the action attribute.
+ # eg. `self.action = 'list'` on an incoming GET request.
+ self.action_map = actions
+
+ # Bind methods to actions
+ # This is the bit that's different to a standard view
+ for method, action in actions.items():
+ handler = getattr(self, action)
+ setattr(self, method, handler)
+
+ # Patch this in as it's otherwise only present from 1.5 onwards
+ if hasattr(self, 'get') and not hasattr(self, 'head'):
+ self.head = self.get
+
+ # And continue as usual
+ return self.dispatch(request, *args, **kwargs)
+
+ # take name and docstring from class
+ update_wrapper(view, cls, updated=())
+
+ # and possible attributes set by decorators
+ # like csrf_exempt from dispatch
+ update_wrapper(view, cls.dispatch, assigned=())
+
+ # We need to set these on the view function, so that breadcrumb
+ # generation can pick out these bits of information from a
+ # resolved URL.
+ view.cls = cls
+ view.suffix = initkwargs.get('suffix', None)
+ return view
+
+ def initialize_request(self, request, *args, **kargs):
+ """
+ Set the `.action` attribute on the view,
+ depending on the request method.
+ """
+ request = super(ViewSetMixin, self).initialize_request(request, *args, **kargs)
+ self.action = self.action_map.get(request.method.lower())
+ return request
+
+
+class ViewSet(ViewSetMixin, views.APIView):
+ """
+ The base ViewSet class does not provide any actions by default.
+ """
+ pass
+
+
+class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
+ mixins.ListModelMixin,
+ ViewSetMixin,
+ generics.GenericAPIView):
+ """
+ A viewset that provides default `list()` and `retrieve()` actions.
+ """
+ pass
+
+
+class ModelViewSet(mixins.CreateModelMixin,
+ mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ mixins.DestroyModelMixin,
+ mixins.ListModelMixin,
+ ViewSetMixin,
+ generics.GenericAPIView):
+ """
+ A viewset that provides default `create()`, `retrieve()`, `update()`,
+ `partial_update()`, `destroy()` and `list()` actions.
+ """
+ pass