diff options
Diffstat (limited to 'rest_framework')
30 files changed, 1499 insertions, 411 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 1eebb5b9..9caca788 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -1,5 +1,5 @@  """ -Provides a set of pluggable authentication policies. +Provides various authentication policies.  """  from __future__ import unicode_literals  import base64 diff --git a/rest_framework/compat.py b/rest_framework/compat.py index f8e4e7ca..cd39f544 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -88,9 +88,7 @@ else:          raise ImportError("User model is not to be found.") -# First implementation of Django class-based views did not include head method -# in base View class - https://code.djangoproject.com/ticket/15668 -if django.VERSION >= (1, 4): +if django.VERSION >= (1, 5):      from django.views.generic import View  else:      from django.views.generic import View as _View @@ -98,6 +96,8 @@ else:      from django.utils.functional import update_wrapper      class View(_View): +        # 1.3 does not include head method in base View class +        # See: https://code.djangoproject.com/ticket/15668          @classonlymethod          def as_view(cls, **initkwargs):              """ @@ -127,11 +127,15 @@ else:              update_wrapper(view, cls.dispatch, assigned=())              return view -# Taken from @markotibold's attempt at supporting PATCH. -# https://github.com/markotibold/django-rest-framework/tree/patch -http_method_names = set(View.http_method_names) -http_method_names.add('patch') -View.http_method_names = list(http_method_names)  # PATCH method is not implemented by Django +        # _allowed_methods only present from 1.5 onwards +        def _allowed_methods(self): +            return [m.upper() for m in self.http_method_names if hasattr(self, m)] + + +# PATCH method is not implemented by Django +if 'patch' not in View.http_method_names: +    View.http_method_names = View.http_method_names + ['patch'] +  # PUT, DELETE do not require CSRF until 1.4.  They should.  Make it better.  if django.VERSION >= (1, 4): diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 8250cd3b..81e585e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,3 +1,11 @@ +""" +The most imporant decorator in this module is `@api_view`, which is used +for writing function-based views with REST framework. + +There are also various decorators for setting the API policies on function +based views, as well as the `@action` and `@link` decorators, which are +used to annotate methods on viewsets that should be included by routers. +"""  from __future__ import unicode_literals  from rest_framework.compat import six  from rest_framework.views import APIView @@ -97,3 +105,25 @@ def permission_classes(permission_classes):          func.permission_classes = permission_classes          return func      return decorator + + +def link(**kwargs): +    """ +    Used to mark a method on a ViewSet that should be routed for GET requests. +    """ +    def decorator(func): +        func.bind_to_method = 'get' +        func.kwargs = kwargs +        return func +    return decorator + + +def action(**kwargs): +    """ +    Used to mark a method on a ViewSet that should be routed for POST requests. +    """ +    def decorator(func): +        func.bind_to_method = 'post' +        func.kwargs = kwargs +        return func +    return decorator diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f3496b53..f934fc39 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,7 +1,13 @@ +""" +Serializer fields perform validation on incoming data. + +They are very similar to Django's form fields. +"""  from __future__ import unicode_literals  import copy  import datetime +from decimal import Decimal, DecimalException  import inspect  import re  import warnings @@ -194,9 +200,9 @@ class WritableField(Field):          # 'blank' is to be deprecated in favor of 'required'          if blank is not None: -            warnings.warn('The `blank` keyword argument is due to deprecated. ' +            warnings.warn('The `blank` keyword argument is deprecated. '                            'Use the `required` keyword argument instead.', -                          PendingDeprecationWarning, stacklevel=2) +                          DeprecationWarning, stacklevel=2)              required = not(blank)          super(WritableField, self).__init__(source=source) @@ -721,6 +727,75 @@ class FloatField(WritableField):              raise ValidationError(msg) +class DecimalField(WritableField): +    type_name = 'DecimalField' +    form_field_class = forms.DecimalField + +    default_error_messages = { +        'invalid': _('Enter a number.'), +        'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), +        'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), +        'max_digits': _('Ensure that there are no more than %s digits in total.'), +        'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), +        'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') +    } + +    def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): +        self.max_value, self.min_value = max_value, min_value +        self.max_digits, self.decimal_places = max_digits, decimal_places +        super(DecimalField, self).__init__(*args, **kwargs) + +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) + +    def from_native(self, value): +        """ +        Validates that the input is a decimal number. Returns a Decimal +        instance. Returns None for empty values. Ensures that there are no more +        than max_digits in the number, and no more than decimal_places digits +        after the decimal point. +        """ +        if value in validators.EMPTY_VALUES: +            return None +        value = smart_text(value).strip() +        try: +            value = Decimal(value) +        except DecimalException: +            raise ValidationError(self.error_messages['invalid']) +        return value + +    def validate(self, value): +        super(DecimalField, self).validate(value) +        if value in validators.EMPTY_VALUES: +            return +        # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, +        # since it is never equal to itself. However, NaN is the only value that +        # isn't equal to itself, so we can use this to identify NaN +        if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): +            raise ValidationError(self.error_messages['invalid']) +        sign, digittuple, exponent = value.as_tuple() +        decimals = abs(exponent) +        # digittuple doesn't include any leading zeros. +        digits = len(digittuple) +        if decimals > digits: +            # We have leading zeros up to or past the decimal point.  Count +            # everything past the decimal point as a digit.  We do not count +            # 0 before the decimal point as a digit since that would mean +            # we would not allow max_digits = decimal_places. +            digits = decimals +        whole_digits = digits - decimals + +        if self.max_digits is not None and digits > self.max_digits: +            raise ValidationError(self.error_messages['max_digits'] % self.max_digits) +        if self.decimal_places is not None and decimals > self.decimal_places: +            raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) +        if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): +            raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) +        return value + +  class FileField(WritableField):      use_files = True      type_name = 'FileField' diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 413fa0d2..571704dc 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -1,5 +1,12 @@ +""" +Provides generic filtering backends that can be used to filter the results +returned by list views. +"""  from __future__ import unicode_literals + +from django.db import models  from rest_framework.compat import django_filters +import operator  FilterSet = django_filters and django_filters.FilterSet or None @@ -58,3 +65,29 @@ class DjangoFilterBackend(BaseFilterBackend):              return filter_class(request.QUERY_PARAMS, queryset=queryset).qs          return queryset + + +class SearchFilter(BaseFilterBackend): +    def construct_search(self, field_name): +        if field_name.startswith('^'): +            return "%s__istartswith" % field_name[1:] +        elif field_name.startswith('='): +            return "%s__iexact" % field_name[1:] +        elif field_name.startswith('@'): +            return "%s__search" % field_name[1:] +        else: +            return "%s__icontains" % field_name + +    def filter_queryset(self, request, queryset, view): +        search_fields = getattr(view, 'search_fields', None) + +        if not search_fields: +            return None + +        orm_lookups = [self.construct_search(str(search_field)) +                       for search_field in self.search_fields] +        for bit in self.query.split(): +            or_queries = [models.Q(**{orm_lookup: bit}) +                          for orm_lookup in orm_lookups] +            queryset = queryset.filter(reduce(operator.or_, or_queries)) +        return queryset diff --git a/rest_framework/generics.py b/rest_framework/generics.py index f9133c73..05ec93d3 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -2,32 +2,59 @@  Generic views that provide commonly needed behaviour.  """  from __future__ import unicode_literals + +from django.core.exceptions import ImproperlyConfigured +from django.core.paginator import Paginator, InvalidPage +from django.http import Http404 +from django.shortcuts import get_object_or_404 +from django.utils.translation import ugettext as _  from rest_framework import views, mixins +from rest_framework.exceptions import ConfigurationError  from rest_framework.settings import api_settings -from django.views.generic.detail import SingleObjectMixin -from django.views.generic.list import MultipleObjectMixin - +import warnings -### Base classes for the generic views ###  class GenericAPIView(views.APIView):      """      Base class for all other generic views.      """ -    model = None +    # You'll need to either set these attributes, +    # or override `get_queryset()`/`get_serializer_class()`. +    queryset = None      serializer_class = None + +    # This shortcut may be used instead of setting either or both +    # of the `queryset`/`serializer_class` attributes, although using +    # the explicit style is generally preferred. +    model = None + +    # If you want to use object lookups other than pk, set this attribute. +    # For more complex lookup requirements override `get_object()`. +    lookup_field = 'pk' + +    # Pagination settings +    paginate_by = api_settings.PAGINATE_BY +    paginate_by_param = api_settings.PAGINATE_BY_PARAM +    pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS +    page_kwarg = 'page' + +    # The filter backend classes to use for queryset filtering +    filter_backends = api_settings.DEFAULT_FILTER_BACKENDS + +    # The following attributes may be subject to change, +    # and should be considered private API.      model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS -    filter_backend = api_settings.FILTER_BACKEND +    paginator_class = Paginator -    def filter_queryset(self, queryset): -        """ -        Given a queryset, filter it with whichever filter backend is in use. -        """ -        if not self.filter_backend: -            return queryset -        backend = self.filter_backend() -        return backend.filter_queryset(self.request, queryset, self) +    ###################################### +    # These are pending deprecation... + +    pk_url_kwarg = 'pk' +    slug_url_kwarg = 'slug' +    slug_field = 'slug' +    allow_empty = True +    filter_backend = api_settings.FILTER_BACKEND      def get_serializer_context(self):          """ @@ -39,24 +66,6 @@ class GenericAPIView(views.APIView):              'view': self          } -    def get_serializer_class(self): -        """ -        Return the class to use for the serializer. - -        Defaults to using `self.serializer_class`, falls back to constructing a -        model serializer class using `self.model_serializer_class`, with -        `self.model` as the model. -        """ -        serializer_class = self.serializer_class - -        if serializer_class is None: -            class DefaultSerializer(self.model_serializer_class): -                class Meta: -                    model = self.model -            serializer_class = DefaultSerializer - -        return serializer_class -      def get_serializer(self, instance=None, data=None,                         files=None, many=False, partial=False):          """ @@ -68,31 +77,7 @@ class GenericAPIView(views.APIView):          return serializer_class(instance, data=data, files=files,                                  many=many, partial=partial, context=context) -    def pre_save(self, obj): -        """ -        Placeholder method for calling before saving an object. -        May be used eg. to set attributes on the object that are implicit -        in either the request, or the url. -        """ -        pass - -    def post_save(self, obj, created=False): -        """ -        Placeholder method for calling after saving an object. -        """ -        pass - - -class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): -    """ -    Base class for generic views onto a queryset. -    """ - -    paginate_by = api_settings.PAGINATE_BY -    paginate_by_param = api_settings.PAGINATE_BY_PARAM -    pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - -    def get_pagination_serializer(self, page=None): +    def get_pagination_serializer(self, page):          """          Return a serializer instance to use with paginated data.          """ @@ -104,41 +89,232 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):          context = self.get_serializer_context()          return pagination_serializer_class(instance=page, context=context) -    def get_paginate_by(self, queryset): +    def paginate_queryset(self, queryset, page_size=None): +        """ +        Paginate a queryset if required, either returning a page object, +        or `None` if pagination is not configured for this view. +        """ +        deprecated_style = False +        if page_size is not None: +            warnings.warn('The `page_size` parameter to `paginate_queryset()` ' +                          'is due to be deprecated. ' +                          'Note that the return style of this method is also ' +                          'changed, and will simply return a page object ' +                          'when called without a `page_size` argument.', +                          PendingDeprecationWarning, stacklevel=2) +            deprecated_style = True +        else: +            # Determine the required page size. +            # If pagination is not configured, simply return None. +            page_size = self.get_paginate_by() +            if not page_size: +                return None + +        if not self.allow_empty: +            warnings.warn( +                'The `allow_empty` parameter is due to be deprecated. ' +                'To use `allow_empty=False` style behavior, You should override ' +                '`get_queryset()` and explicitly raise a 404 on empty querysets.', +                PendingDeprecationWarning, stacklevel=2 +            ) + +        paginator = self.paginator_class(queryset, page_size, +                                         allow_empty_first_page=self.allow_empty) +        page_kwarg = self.kwargs.get(self.page_kwarg) +        page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) +        page = page_kwarg or page_query_param or 1 +        try: +            page_number = int(page) +        except ValueError: +            if page == 'last': +                page_number = paginator.num_pages +            else: +                raise Http404(_("Page is not 'last', nor can it be converted to an int.")) +        try: +            page = paginator.page(page_number) +        except InvalidPage as e: +            raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { +                                'page_number': page_number, +                                'message': str(e) +            }) + +        if deprecated_style: +            return (paginator, page, page.object_list, page.has_other_pages()) +        return page + +    def filter_queryset(self, queryset): +        """ +        Given a queryset, filter it with whichever filter backend is in use. + +        You are unlikely to want to override this method, although you may need +        to call it either from a list view, or from a custom `get_object` +        method if you want to apply the configured filtering backend to the +        default queryset. +        """ +        filter_backends = self.filter_backends or [] +        if not filter_backends and self.filter_backend: +            warnings.warn( +                'The `filter_backend` attribute and `FILTER_BACKEND` setting ' +                'are due to be deprecated in favor of a `filter_backends` ' +                'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take ' +                'a *list* of filter backend classes.', +                PendingDeprecationWarning, stacklevel=2 +            ) +            filter_backends = [self.filter_backend] + +        for backend in filter_backends: +            queryset = backend().filter_queryset(self.request, queryset, self) +        return queryset + +    ######################## +    ### The following methods provide default implementations +    ### that you may want to override for more complex cases. + +    def get_paginate_by(self, queryset=None):          """          Return the size of pages to use with pagination. + +        If `PAGINATE_BY_PARAM` is set it will attempt to get the page size +        from a named query parameter in the url, eg. ?page_size=100 + +        Otherwise defaults to using `self.paginate_by`.          """ +        if queryset is not None: +            warnings.warn('The `queryset` parameter to `get_paginate_by()` ' +                          'is due to be deprecated.', +                          PendingDeprecationWarning, stacklevel=2) +          if self.paginate_by_param:              query_params = self.request.QUERY_PARAMS              try:                  return int(query_params[self.paginate_by_param])              except (KeyError, ValueError):                  pass +          return self.paginate_by +    def get_serializer_class(self): +        """ +        Return the class to use for the serializer. +        Defaults to using `self.serializer_class`. + +        You may want to override this if you need to provide different +        serializations depending on the incoming request. -class SingleObjectAPIView(SingleObjectMixin, GenericAPIView): -    """ -    Base class for generic views onto a model instance. -    """ +        (Eg. admins get full serialization, others get basic serilization) +        """ +        serializer_class = self.serializer_class +        if serializer_class is not None: +            return serializer_class -    pk_url_kwarg = 'pk'  # Not provided in Django 1.3 -    slug_url_kwarg = 'slug'  # Not provided in Django 1.3 -    slug_field = 'slug' +        assert self.model is not None, \ +            "'%s' should either include a 'serializer_class' attribute, " \ +            "or use the 'model' attribute as a shortcut for " \ +            "automatically generating a serializer class." \ +            % self.__class__.__name__ + +        class DefaultSerializer(self.model_serializer_class): +            class Meta: +                model = self.model +        return DefaultSerializer + +    def get_queryset(self): +        """ +        Get the list of items for this view. +        This must be an iterable, and may be a queryset. +        Defaults to using `self.queryset`. + +        You may want to override this if you need to provide different +        querysets depending on the incoming request. + +        (Eg. return a list of items that is specific to the user) +        """ +        if self.queryset is not None: +            return self.queryset._clone() + +        if self.model is not None: +            return self.model._default_manager.all() + +        raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" +                                    % self.__class__.__name__)      def get_object(self, queryset=None):          """ -        Override default to add support for object-level permissions. +        Returns the object the view is displaying. + +        You may want to override this if you need to provide non-standard +        queryset lookups.  Eg if objects are referenced using multiple +        keyword arguments in the url conf.          """ -        queryset = self.filter_queryset(self.get_queryset()) -        obj = super(SingleObjectAPIView, self).get_object(queryset) +        # Determine the base queryset to use. +        if queryset is None: +            queryset = self.filter_queryset(self.get_queryset()) +        else: +            pass  # Deprecation warning + +        # Perform the lookup filtering. +        pk = self.kwargs.get(self.pk_url_kwarg, None) +        slug = self.kwargs.get(self.slug_url_kwarg, None) +        lookup = self.kwargs.get(self.lookup_field, None) + +        if lookup is not None: +            filter_kwargs = {self.lookup_field: lookup} +        elif pk is not None and self.lookup_field == 'pk': +            warnings.warn( +                'The `pk_url_kwarg` attribute is due to be deprecated. ' +                'Use the `lookup_field` attribute instead', +                PendingDeprecationWarning +            ) +            filter_kwargs = {'pk': pk} +        elif slug is not None and self.lookup_field == 'pk': +            warnings.warn( +                'The `slug_url_kwarg` attribute is due to be deprecated. ' +                'Use the `lookup_field` attribute instead', +                PendingDeprecationWarning +            ) +            filter_kwargs = {self.slug_field: slug} +        else: +            raise ConfigurationError( +                'Expected view %s to be called with a URL keyword argument ' +                'named "%s". Fix your URL conf, or set the `.lookup_field` ' +                'attribute on the view correctly.' % +                (self.__class__.__name__, self.lookup_field) +            ) + +        obj = get_object_or_404(queryset, **filter_kwargs) + +        # May raise a permission denied          self.check_object_permissions(self.request, obj) +          return obj +    ######################## +    ### The following are placeholder methods, +    ### and are intended to be overridden. +    ### +    ### The are not called by GenericAPIView directly, +    ### but are used by the mixin methods. + +    def pre_save(self, obj): +        """ +        Placeholder method for calling before saving an object. + +        May be used to set attributes on the object that are implicit +        in either the request, or the url. +        """ +        pass + +    def post_save(self, obj, created=False): +        """ +        Placeholder method for calling after saving an object. +        """ +        pass -### Concrete view classes that provide method handlers ### -### by composing the mixin classes with a base view.   ### +########################################################## +### Concrete view classes that provide method handlers ### +### by composing the mixin classes with the base view. ### +##########################################################  class CreateAPIView(mixins.CreateModelMixin,                      GenericAPIView): @@ -151,7 +327,7 @@ class CreateAPIView(mixins.CreateModelMixin,  class ListAPIView(mixins.ListModelMixin, -                  MultipleObjectAPIView): +                  GenericAPIView):      """      Concrete view for listing a queryset.      """ @@ -160,7 +336,7 @@ class ListAPIView(mixins.ListModelMixin,  class RetrieveAPIView(mixins.RetrieveModelMixin, -                      SingleObjectAPIView): +                      GenericAPIView):      """      Concrete view for retrieving a model instance.      """ @@ -169,7 +345,7 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,  class DestroyAPIView(mixins.DestroyModelMixin, -                     SingleObjectAPIView): +                     GenericAPIView):      """      Concrete view for deleting a model instance. @@ -179,7 +355,7 @@ class DestroyAPIView(mixins.DestroyModelMixin,  class UpdateAPIView(mixins.UpdateModelMixin, -                    SingleObjectAPIView): +                    GenericAPIView):      """      Concrete view for updating a model instance. @@ -188,13 +364,12 @@ class UpdateAPIView(mixins.UpdateModelMixin,          return self.update(request, *args, **kwargs)      def patch(self, request, *args, **kwargs): -        kwargs['partial'] = True -        return self.update(request, *args, **kwargs) +        return self.partial_update(request, *args, **kwargs)  class ListCreateAPIView(mixins.ListModelMixin,                          mixins.CreateModelMixin, -                        MultipleObjectAPIView): +                        GenericAPIView):      """      Concrete view for listing a queryset or creating a model instance.      """ @@ -207,7 +382,7 @@ class ListCreateAPIView(mixins.ListModelMixin,  class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,                              mixins.UpdateModelMixin, -                            SingleObjectAPIView): +                            GenericAPIView):      """      Concrete view for retrieving, updating a model instance.      """ @@ -218,13 +393,12 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,          return self.update(request, *args, **kwargs)      def patch(self, request, *args, **kwargs): -        kwargs['partial'] = True -        return self.update(request, *args, **kwargs) +        return self.partial_update(request, *args, **kwargs)  class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,                               mixins.DestroyModelMixin, -                             SingleObjectAPIView): +                             GenericAPIView):      """      Concrete view for retrieving or deleting a model instance.      """ @@ -238,7 +412,7 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,  class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,                                     mixins.UpdateModelMixin,                                     mixins.DestroyModelMixin, -                                   SingleObjectAPIView): +                                   GenericAPIView):      """      Concrete view for retrieving, updating or deleting a model instance.      """ @@ -249,8 +423,31 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,          return self.update(request, *args, **kwargs)      def patch(self, request, *args, **kwargs): -        kwargs['partial'] = True -        return self.update(request, *args, **kwargs) +        return self.partial_update(request, *args, **kwargs)      def delete(self, request, *args, **kwargs):          return self.destroy(request, *args, **kwargs) + + +########################## +### Deprecated classes ### +########################## + +class MultipleObjectAPIView(GenericAPIView): +    def __init__(self, *args, **kwargs): +        warnings.warn( +            'Subclassing `MultipleObjectAPIView` is due to be deprecated. ' +            'You should simply subclass `GenericAPIView` instead.', +            PendingDeprecationWarning, stacklevel=2 +        ) +        super(MultipleObjectAPIView, self).__init__(*args, **kwargs) + + +class SingleObjectAPIView(GenericAPIView): +    def __init__(self, *args, **kwargs): +        warnings.warn( +            'Subclassing `SingleObjectAPIView` is due to be deprecated. ' +            'You should simply subclass `GenericAPIView` instead.', +            PendingDeprecationWarning, stacklevel=2 +        ) +        super(SingleObjectAPIView, self).__init__(*args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 3bd7d6df..ae703771 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -12,7 +12,7 @@ from rest_framework.response import Response  from rest_framework.request import clone_request -def _get_validation_exclusions(obj, pk=None, slug_field=None): +def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):      """      Given a model instance, and an optional pk and slug field,      return the full list of all other field names on that model. @@ -23,14 +23,19 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None):      include = []      if pk: +        # Pending deprecation          pk_field = obj._meta.pk          while pk_field.rel:              pk_field = pk_field.rel.to._meta.pk          include.append(pk_field.name)      if slug_field: +        # Pending deprecation          include.append(slug_field) +    if lookup_field and lookup_field != 'pk': +        include.append(lookup_field) +      return [field.name for field in obj._meta.fields if field.name not in include] @@ -67,23 +72,18 @@ class ListModelMixin(object):      empty_error = "Empty list and '%(class_name)s.allow_empty' is False."      def list(self, request, *args, **kwargs): -        queryset = self.get_queryset() -        self.object_list = self.filter_queryset(queryset) +        self.object_list = self.filter_queryset(self.get_queryset())          # Default is to allow empty querysets.  This can be altered by setting          # `.allow_empty = False`, to raise 404 errors on empty querysets. -        allow_empty = self.get_allow_empty() -        if not allow_empty and not self.object_list: +        if not self.allow_empty and not self.object_list:              class_name = self.__class__.__name__              error_msg = self.empty_error % {'class_name': class_name}              raise Http404(error_msg) -        # Pagination size is set by the `.paginate_by` attribute, -        # which may be `None` to disable pagination. -        page_size = self.get_paginate_by(self.object_list) -        if page_size: -            packed = self.paginate_queryset(self.object_list, page_size) -            paginator, page, queryset, is_paginated = packed +        # Switch between paginated or standard style responses +        page = self.paginate_queryset(self.object_list) +        if page is not None:              serializer = self.get_pagination_serializer(page)          else:              serializer = self.get_serializer(self.object_list, many=True) @@ -135,14 +135,22 @@ class UpdateModelMixin(object):          return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) +    def partial_update(self, request, *args, **kwargs): +        kwargs['partial'] = True +        return self.update(request, *args, **kwargs) +      def pre_save(self, obj):          """          Set any attributes on the object that are implicit in the request.          """          # pk and/or slug attributes are implicit in the URL. +        lookup = self.kwargs.get(self.lookup_field, None)          pk = self.kwargs.get(self.pk_url_kwarg, None)          slug = self.kwargs.get(self.slug_url_kwarg, None) -        slug_field = slug and self.get_slug_field() or None +        slug_field = slug and self.slug_field or None + +        if lookup: +            setattr(obj, self.lookup_field, lookup)          if pk:              setattr(obj, 'pk', pk) @@ -153,7 +161,7 @@ class UpdateModelMixin(object):          # Ensure we clean the attributes so that we don't eg return integer          # pk using a string representation, as provided by the url conf kwarg.          if hasattr(obj, 'full_clean'): -            exclude = _get_validation_exclusions(obj, pk, slug_field) +            exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field)              obj.full_clean(exclude) diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index 0694d35f..4d205c0e 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -1,3 +1,7 @@ +""" +Content negotiation deals with selecting an appropriate renderer given the +incoming request.  Typically this will be based on the request's Accept header. +"""  from __future__ import unicode_literals  from django.http import Http404  from rest_framework import exceptions diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 03a7a30f..d51ea929 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -1,9 +1,11 @@ +""" +Pagination serializers determine the structure of the output that should +be used for paginated responses. +"""  from __future__ import unicode_literals  from rest_framework import serializers  from rest_framework.templatetags.rest_framework import replace_query_param -# TODO: Support URLconf kwarg-style paging -  class NextPageField(serializers.Field):      """ diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index ae895f39..751f31a7 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -25,10 +25,12 @@ class BasePermission(object):          """          Return `True` if permission is granted, `False` otherwise.          """ -        if len(inspect.getargspec(self.has_permission)[0]) == 4: -            warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. ' -                      'Use `has_object_permission()` instead for object permissions.', -                       PendingDeprecationWarning, stacklevel=2) +        if len(inspect.getargspec(self.has_permission).args) == 4: +            warnings.warn( +                'The `obj` argument in `has_permission` is deprecated. ' +                'Use `has_object_permission()` instead for object permissions.', +                DeprecationWarning, stacklevel=2 +            )              return self.has_permission(request, view, obj)          return True @@ -87,8 +89,8 @@ class DjangoModelPermissions(BasePermission):      It ensures that the user is authenticated, and has the appropriate      `add`/`change`/`delete` permissions on the model. -    This permission will only be applied against view classes that -    provide a `.model` attribute, such as the generic class-based views. +    This permission can only be applied against view classes that +    provide a `.model` or `.queryset` attribute.      """      # Map methods into required permission codes. @@ -136,6 +138,14 @@ class DjangoModelPermissions(BasePermission):          return False +class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): +    """ +    Similar to DjangoModelPermissions, except that anonymous users are +    allowed read-only access. +    """ +    authenticated_users_only = False + +  class TokenHasReadWriteScope(BasePermission):      """      The request is authenticated as a user and the token used has the right scope diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 2a10e9af..fc5054b2 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,3 +1,9 @@ +""" +Serializer fields that deal with relationships. + +These fields allow you to specify the style that should be used to represent +model relationships, including hyperlinks, primary keys, or slugs. +"""  from __future__ import unicode_literals  from django.core.exceptions import ObjectDoesNotExist, ValidationError  from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch @@ -36,9 +42,9 @@ class RelatedField(WritableField):          # 'null' is to be deprecated in favor of 'required'          if 'null' in kwargs: -            warnings.warn('The `null` keyword argument is due to be deprecated. ' +            warnings.warn('The `null` keyword argument is deprecated. '                            'Use the `required` keyword argument instead.', -                          PendingDeprecationWarning, stacklevel=2) +                          DeprecationWarning, stacklevel=2)              kwargs['required'] = not kwargs.pop('null')          self.queryset = kwargs.pop('queryset', None) @@ -282,10 +288,8 @@ class HyperlinkedRelatedField(RelatedField):      """      Represents a relationship using hyperlinking.      """ -    pk_url_kwarg = 'pk' -    slug_field = 'slug' -    slug_url_kwarg = None  # Defaults to same as `slug_field` unless overridden      read_only = False +    lookup_field = 'pk'      default_error_messages = {          'no_match': _('Invalid hyperlink - No URL match'), @@ -295,69 +299,138 @@ class HyperlinkedRelatedField(RelatedField):          'incorrect_type': _('Incorrect type.  Expected url string, received %s.'),      } +    # These are all pending deprecation +    pk_url_kwarg = 'pk' +    slug_field = 'slug' +    slug_url_kwarg = None  # Defaults to same as `slug_field` unless overridden +      def __init__(self, *args, **kwargs):          try:              self.view_name = kwargs.pop('view_name')          except KeyError:              raise ValueError("Hyperlinked field requires 'view_name' kwarg") +        self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) +        self.format = kwargs.pop('format', None) + +        # These are pending deprecation +        if 'pk_url_kwarg' in kwargs: +            msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.' +            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +        if 'slug_url_kwarg' in kwargs: +            msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.' +            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +        if 'slug_field' in kwargs: +            msg = 'slug_field is pending deprecation. Use lookup_field instead.' +            warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + +        self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)          self.slug_field = kwargs.pop('slug_field', self.slug_field)          default_slug_kwarg = self.slug_url_kwarg or self.slug_field -        self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)          self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg) -        self.format = kwargs.pop('format', None)          super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) -    def get_slug_field(self): +    def get_url(self, obj, view_name, request, format):          """ -        Get the name of a slug field to be used to look up by slug. -        """ -        return self.slug_field +        Given an object, return the URL that hyperlinks to the object. -    def to_native(self, obj): -        view_name = self.view_name -        request = self.context.get('request', None) -        format = self.format or self.context.get('format', None) - -        if request is None: -            warnings.warn("Using `HyperlinkedRelatedField` without including the " -                          "request in the serializer context is due to be deprecated. " -                          "Add `context={'request': request}` when instantiating the serializer.", -                          PendingDeprecationWarning, stacklevel=4) - -        pk = getattr(obj, 'pk', None) -        if pk is None: -            return -        kwargs = {self.pk_url_kwarg: pk} +        May raise a `NoReverseMatch` if the `view_name` and `lookup_field` +        attributes are not configured to correctly match the URL conf. +        """ +        lookup_field = getattr(obj, self.lookup_field) +        kwargs = {self.lookup_field: lookup_field}          try:              return reverse(view_name, kwargs=kwargs, request=request, format=format)          except NoReverseMatch:              pass +        if self.pk_url_kwarg != 'pk': +            # Only try pk if it has been explicitly set. +            # Otherwise, the default `lookup_field = 'pk'` has us covered. +            pk = obj.pk +            kwargs = {self.pk_url_kwarg: pk} +            try: +                return reverse(view_name, kwargs=kwargs, request=request, format=format) +            except NoReverseMatch: +                pass +          slug = getattr(obj, self.slug_field, None) +        if slug is not None: +            # Only try slug if it corresponds to an attribute on the object. +            kwargs = {self.slug_url_kwarg: slug} +            try: +                ret = reverse(view_name, kwargs=kwargs, request=request, format=format) +                if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug': +                    # If the lookup succeeds using the default slug params, +                    # then `slug_field` is being used implicitly, and we +                    # we need to warn about the pending deprecation. +                    msg = 'Implicit slug field hyperlinked fields are pending deprecation.' \ +                          'You should set `lookup_field=slug` on the HyperlinkedRelatedField.' +                    warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) +                return ret +            except NoReverseMatch: +                pass + +        raise NoReverseMatch() + +    def get_object(self, queryset, view_name, view_args, view_kwargs): +        """ +        Return the object corresponding to a matched URL. -        if not slug: -            raise Exception('Could not resolve URL for field using view name "%s"' % view_name) +        Takes the matched URL conf arguments, and the queryset, and should +        return an object instance, or raise an `ObjectDoesNotExist` exception. +        """ +        lookup = view_kwargs.get(self.lookup_field, None) +        pk = view_kwargs.get(self.pk_url_kwarg, None) +        slug = view_kwargs.get(self.slug_url_kwarg, None) + +        if lookup is not None: +            filter_kwargs = {self.lookup_field: lookup} +        elif pk is not None: +            filter_kwargs = {'pk': pk} +        elif slug is not None: +            filter_kwargs = {self.slug_field: slug} +        else: +            raise ObjectDoesNotExist() -        kwargs = {self.slug_url_kwarg: slug} -        try: -            return reverse(view_name, kwargs=kwargs, request=request, format=format) -        except NoReverseMatch: -            pass +        return queryset.get(**filter_kwargs) -        kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} +    def to_native(self, obj): +        view_name = self.view_name +        request = self.context.get('request', None) +        format = self.format or self.context.get('format', None) + +        if request is None: +            msg = ( +                "Using `HyperlinkedRelatedField` without including the request " +                "in the serializer context is deprecated. " +                "Add `context={'request': request}` when instantiating " +                "the serializer." +            ) +            warnings.warn(msg, DeprecationWarning, stacklevel=4) + +        # If the object has not yet been saved then we cannot hyperlink to it. +        if getattr(obj, 'pk', None) is None: +            return + +        # Return the hyperlink, or error if incorrectly configured.          try: -            return reverse(view_name, kwargs=kwargs, request=request, format=format) +            return self.get_url(obj, view_name, request, format)          except NoReverseMatch: -            pass - -        raise Exception('Could not resolve URL for field using view name "%s"' % view_name) +            msg = ( +                'Could not resolve URL for hyperlinked relationship using ' +                'view name "%s". You may have failed to include the related ' +                'model in your API, or incorrectly configured the ' +                '`lookup_field` attribute on this field.' +            ) +            raise Exception(msg % view_name)      def from_native(self, value):          # Convert URL -> model instance pk          # TODO: Use values_list -        if self.queryset is None: +        queryset = self.queryset +        if queryset is None:              raise Exception('Writable related fields must include a `queryset` argument')          try: @@ -381,29 +454,11 @@ class HyperlinkedRelatedField(RelatedField):          if match.view_name != self.view_name:              raise ValidationError(self.error_messages['incorrect_match']) -        pk = match.kwargs.get(self.pk_url_kwarg, None) -        slug = match.kwargs.get(self.slug_url_kwarg, None) - -        # Try explicit primary key. -        if pk is not None: -            queryset = self.queryset.filter(pk=pk) -        # Next, try looking up by slug. -        elif slug is not None: -            slug_field = self.get_slug_field() -            queryset = self.queryset.filter(**{slug_field: slug}) -        # If none of those are defined, it's probably a configuation error. -        else: -            raise ValidationError(self.error_messages['configuration_error']) -          try: -            obj = queryset.get() -        except ObjectDoesNotExist: +            return self.get_object(queryset, match.view_name, +                                   match.args, match.kwargs) +        except (ObjectDoesNotExist, TypeError, ValueError):              raise ValidationError(self.error_messages['does_not_exist']) -        except (TypeError, ValueError): -            msg = self.error_messages['incorrect_type'] -            raise ValidationError(msg % type(value).__name__) - -        return obj  class HyperlinkedIdentityField(Field): @@ -437,9 +492,9 @@ class HyperlinkedIdentityField(Field):          if request is None:              warnings.warn("Using `HyperlinkedIdentityField` without including the " -                          "request in the serializer context is due to be deprecated. " +                          "request in the serializer context is deprecated. "                            "Add `context={'request': request}` when instantiating the serializer.", -                          PendingDeprecationWarning, stacklevel=4) +                          DeprecationWarning, stacklevel=4)          # By default use whatever format is given for the current context          # unless the target is a different type to the source. @@ -482,35 +537,35 @@ class HyperlinkedIdentityField(Field):  class ManyRelatedField(RelatedField):      def __init__(self, *args, **kwargs): -        warnings.warn('`ManyRelatedField()` is due to be deprecated. ' +        warnings.warn('`ManyRelatedField()` is deprecated. '                        'Use `RelatedField(many=True)` instead.', -                       PendingDeprecationWarning, stacklevel=2) +                       DeprecationWarning, stacklevel=2)          kwargs['many'] = True          super(ManyRelatedField, self).__init__(*args, **kwargs)  class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):      def __init__(self, *args, **kwargs): -        warnings.warn('`ManyPrimaryKeyRelatedField()` is due to be deprecated. ' +        warnings.warn('`ManyPrimaryKeyRelatedField()` is deprecated. '                        'Use `PrimaryKeyRelatedField(many=True)` instead.', -                       PendingDeprecationWarning, stacklevel=2) +                       DeprecationWarning, stacklevel=2)          kwargs['many'] = True          super(ManyPrimaryKeyRelatedField, self).__init__(*args, **kwargs)  class ManySlugRelatedField(SlugRelatedField):      def __init__(self, *args, **kwargs): -        warnings.warn('`ManySlugRelatedField()` is due to be deprecated. ' +        warnings.warn('`ManySlugRelatedField()` is deprecated. '                        'Use `SlugRelatedField(many=True)` instead.', -                       PendingDeprecationWarning, stacklevel=2) +                       DeprecationWarning, stacklevel=2)          kwargs['many'] = True          super(ManySlugRelatedField, self).__init__(*args, **kwargs)  class ManyHyperlinkedRelatedField(HyperlinkedRelatedField):      def __init__(self, *args, **kwargs): -        warnings.warn('`ManyHyperlinkedRelatedField()` is due to be deprecated. ' +        warnings.warn('`ManyHyperlinkedRelatedField()` is deprecated. '                        'Use `HyperlinkedRelatedField(many=True)` instead.', -                       PendingDeprecationWarning, stacklevel=2) +                       DeprecationWarning, stacklevel=2)          kwargs['many'] = True          super(ManyHyperlinkedRelatedField, self).__init__(*args, **kwargs) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 83bbc5b8..1917a080 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -24,6 +24,7 @@ from rest_framework.settings import api_settings  from rest_framework.request import clone_request  from rest_framework.utils import encoders  from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.utils.formatting import get_view_name, get_view_description  from rest_framework import exceptions, parsers, status, VERSION @@ -438,16 +439,13 @@ class BrowsableAPIRenderer(BaseRenderer):          return GenericContentForm()      def get_name(self, view): -        try: -            return view.get_name() -        except AttributeError: -            return smart_text(view.__class__.__name__) +        return get_view_name(view.__class__, getattr(view, 'suffix', None))      def get_description(self, view): -        try: -            return view.get_description(html=True) -        except AttributeError: -            return smart_text(view.__doc__ or '') +        return get_view_description(view.__class__, html=True) + +    def get_breadcrumbs(self, request): +        return get_breadcrumbs(request.path)      def render(self, data, accepted_media_type=None, renderer_context=None):          """ @@ -480,7 +478,7 @@ class BrowsableAPIRenderer(BaseRenderer):          name = self.get_name(view)          description = self.get_description(view) -        breadcrumb_list = get_breadcrumbs(request.path) +        breadcrumb_list = self.get_breadcrumbs(request)          template = loader.get_template(self.template)          context = RequestContext(request, { diff --git a/rest_framework/request.py b/rest_framework/request.py index ffbbab33..a434659c 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -1,11 +1,10 @@  """ -The :mod:`request` module provides a :class:`Request` class used to wrap the standard `request` -object received in all the views. +The Request class is used as a wrapper around the standard request object.  The wrapped request then offers a richer API, in particular :      - content automatically parsed according to `Content-Type` header, -      and available as :meth:`.DATA<Request.DATA>` +      and available as `request.DATA`      - full support of PUT method, including support for file uploads      - form overloading of HTTP method, content type and content  """ diff --git a/rest_framework/response.py b/rest_framework/response.py index 5e1bf46e..26e4ab37 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -1,3 +1,9 @@ +""" +The Response class in REST framework is similiar to HTTPResponse, except that +it is initialized with unrendered data, instead of a pre-rendered string. + +The appropriate renderer is called during Django's template response rendering. +"""  from __future__ import unicode_literals  from django.core.handlers.wsgi import STATUS_CODE_TEXT  from django.template.response import SimpleTemplateResponse diff --git a/rest_framework/routers.py b/rest_framework/routers.py new file mode 100644 index 00000000..0707635a --- /dev/null +++ b/rest_framework/routers.py @@ -0,0 +1,246 @@ +""" +Routers provide a convenient and consistent way of automatically +determining the URL conf for your API. + +They are used by simply instantiating a Router class, and then registering +all the required ViewSets with that router. + +For example, you might have a `urls.py` that looks something like this: + +    router = routers.DefaultRouter() +    router.register('users', UserViewSet, 'user') +    router.register('accounts', AccountViewSet, 'account') + +    urlpatterns = router.urls +""" +from __future__ import unicode_literals + +from collections import namedtuple +from django.conf.urls import url, patterns +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.reverse import reverse +from rest_framework.urlpatterns import format_suffix_patterns + + +Route = namedtuple('Route', ['url', 'mapping', 'name', 'initkwargs']) + + +def replace_methodname(format_string, methodname): +    """ +    Partially format a format_string, swapping out any +    '{methodname}' or '{methodnamehyphen}' components. +    """ +    methodnamehyphen = methodname.replace('_', '-') +    ret = format_string +    ret = ret.replace('{methodname}', methodname) +    ret = ret.replace('{methodnamehyphen}', methodnamehyphen) +    return ret + + +class BaseRouter(object): +    def __init__(self): +        self.registry = [] + +    def register(self, prefix, viewset, base_name=None): +        if base_name is None: +            base_name = self.get_default_base_name(viewset) +        self.registry.append((prefix, viewset, base_name)) + +    def get_default_base_name(self, viewset): +        """ +        If `base_name` is not specified, attempt to automatically determine +        it from the viewset. +        """ +        raise NotImplemented('get_default_base_name must be overridden') + +    def get_urls(self): +        """ +        Return a list of URL patterns, given the registered viewsets. +        """ +        raise NotImplemented('get_urls must be overridden') + +    @property +    def urls(self): +        if not hasattr(self, '_urls'): +            self._urls = patterns('', *self.get_urls()) +        return self._urls + + +class SimpleRouter(BaseRouter): +    routes = [ +        # List route. +        Route( +            url=r'^{prefix}/$', +            mapping={ +                'get': 'list', +                'post': 'create' +            }, +            name='{basename}-list', +            initkwargs={'suffix': 'List'} +        ), +        # Detail route. +        Route( +            url=r'^{prefix}/{lookup}/$', +            mapping={ +                'get': 'retrieve', +                'put': 'update', +                'patch': 'partial_update', +                'delete': 'destroy' +            }, +            name='{basename}-detail', +            initkwargs={'suffix': 'Instance'} +        ), +        # Dynamically generated routes. +        # Generated using @action or @link decorators on methods of the viewset. +        Route( +            url=r'^{prefix}/{lookup}/{methodname}/$', +            mapping={ +                '{httpmethod}': '{methodname}', +            }, +            name='{basename}-{methodnamehyphen}', +            initkwargs={} +        ), +    ] + +    def get_default_base_name(self, viewset): +        """ +        If `base_name` is not specified, attempt to automatically determine +        it from the viewset. +        """ +        model_cls = getattr(viewset, 'model', None) +        queryset = getattr(viewset, 'queryset', None) +        if model_cls is None and queryset is not None: +            model_cls = queryset.model + +        assert model_cls, '`name` not argument not specified, and could ' \ +            'not automatically determine the name from the viewset, as ' \ +            'it does not have a `.model` or `.queryset` attribute.' + +        return model_cls._meta.object_name.lower() + +    def get_routes(self, viewset): +        """ +        Augment `self.routes` with any dynamically generated routes. + +        Returns a list of the Route namedtuple. +        """ + +        # Determine any `@action` or `@link` decorated methods on the viewset +        dynamic_routes = {} +        for methodname in dir(viewset): +            attr = getattr(viewset, methodname) +            httpmethod = getattr(attr, 'bind_to_method', None) +            if httpmethod: +                dynamic_routes[httpmethod] = methodname + +        ret = [] +        for route in self.routes: +            if route.mapping == {'{httpmethod}': '{methodname}'}: +                # Dynamic routes (@link or @action decorator) +                for httpmethod, methodname in dynamic_routes.items(): +                    initkwargs = route.initkwargs.copy() +                    initkwargs.update(getattr(viewset, methodname).kwargs) +                    ret.append(Route( +                        url=replace_methodname(route.url, methodname), +                        mapping={httpmethod: methodname}, +                        name=replace_methodname(route.name, methodname), +                        initkwargs=initkwargs, +                    )) +            else: +                # Standard route +                ret.append(route) + +        return ret + +    def get_method_map(self, viewset, method_map): +        """ +        Given a viewset, and a mapping of http methods to actions, +        return a new mapping which only includes any mappings that +        are actually implemented by the viewset. +        """ +        bound_methods = {} +        for method, action in method_map.items(): +            if hasattr(viewset, action): +                bound_methods[method] = action +        return bound_methods + +    def get_lookup_regex(self, viewset): +        """ +        Given a viewset, return the portion of URL regex that is used +        to match against a single instance. +        """ +        base_regex = '(?P<{lookup_field}>[^/]+)' +        lookup_field = getattr(viewset, 'lookup_field', 'pk') +        return base_regex.format(lookup_field=lookup_field) + +    def get_urls(self): +        """ +        Use the registered viewsets to generate a list of URL patterns. +        """ +        ret = [] + +        for prefix, viewset, basename in self.registry: +            lookup = self.get_lookup_regex(viewset) +            routes = self.get_routes(viewset) + +            for route in routes: + +                # Only actions which actually exist on the viewset will be bound +                mapping = self.get_method_map(viewset, route.mapping) +                if not mapping: +                    continue + +                # Build the url pattern +                regex = route.url.format(prefix=prefix, lookup=lookup) +                view = viewset.as_view(mapping, **route.initkwargs) +                name = route.name.format(basename=basename) +                ret.append(url(regex, view, name=name)) + +        return ret + + +class DefaultRouter(SimpleRouter): +    """ +    The default router extends the SimpleRouter, but also adds in a default +    API root view, and adds format suffix patterns to the URLs. +    """ +    include_root_view = True +    include_format_suffixes = True + +    def get_api_root_view(self): +        """ +        Return a view to use as the API root. +        """ +        api_root_dict = {} +        list_name = self.routes[0].name +        for prefix, viewset, basename in self.registry: +            api_root_dict[prefix] = list_name.format(basename=basename) + +        @api_view(('GET',)) +        def api_root(request, format=None): +            ret = {} +            for key, url_name in api_root_dict.items(): +                ret[key] = reverse(url_name, request=request, format=format) +            return Response(ret) + +        return api_root + +    def get_urls(self): +        """ +        Generate the list of URL patterns, including a default root view +        for the API, and appending `.json` style format suffixes. +        """ +        urls = [] + +        if self.include_root_view: +            root_url = url(r'^$', self.get_api_root_view(), name='api-root') +            urls.append(root_url) + +        default_urls = super(DefaultRouter, self).get_urls() +        urls.extend(default_urls) + +        if self.include_format_suffixes: +            urls = format_suffix_patterns(urls) + +        return urls diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index add46566..ea5175e2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1,3 +1,15 @@ +""" +Serializers and ModelSerializers are similar to Forms and ModelForms. +Unlike forms, they are not constrained to dealing with HTML output, and +form encoded input. + +Serialization in REST framework is a two-phase process: + +1. Serializers marshal between complex types like model instances, and +python primatives. +2. The process of marshalling between python primatives and request and +response content is handled by parsers and renderers. +"""  from __future__ import unicode_literals  import copy  import datetime @@ -412,9 +424,9 @@ class BaseSerializer(WritableField):              else:                  many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))                  if many: -                    warnings.warn('Implict list/queryset serialization is due to be deprecated. ' +                    warnings.warn('Implict list/queryset serialization is deprecated. '                                    'Use the `many=True` flag when instantiating the serializer.', -                                  PendingDeprecationWarning, stacklevel=3) +                                  DeprecationWarning, stacklevel=3)              if many:                  ret = [] @@ -474,9 +486,9 @@ class BaseSerializer(WritableField):              else:                  many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))                  if many: -                    warnings.warn('Implict list/queryset serialization is due to be deprecated. ' +                    warnings.warn('Implict list/queryset serialization is deprecated. '                                    'Use the `many=True` flag when instantiating the serializer.', -                                  PendingDeprecationWarning, stacklevel=2) +                                  DeprecationWarning, stacklevel=2)              if many:                  self._data = [self.to_native(item) for item in obj] @@ -536,6 +548,7 @@ class ModelSerializer(Serializer):          models.DateTimeField: DateTimeField,          models.DateField: DateField,          models.TimeField: TimeField, +        models.DecimalField: DecimalField,          models.EmailField: EmailField,          models.CharField: CharField,          models.URLField: URLField, @@ -556,36 +569,85 @@ class ModelSerializer(Serializer):          assert cls is not None, \                  "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__          opts = get_concrete_model(cls)._meta -        pk_field = opts.pk +        ret = SortedDict() +        nested = bool(self.opts.depth) -        # If model is a child via multitable inheritance, use parent's pk +        # Deal with adding the primary key field +        pk_field = opts.pk          while pk_field.rel and pk_field.rel.parent_link: +            # If model is a child via multitable inheritance, use parent's pk              pk_field = pk_field.rel.to._meta.pk -        fields = [pk_field] -        fields += [field for field in opts.fields if field.serialize] -        fields += [field for field in opts.many_to_many if field.serialize] +        field = self.get_pk_field(pk_field) +        if field: +            ret[pk_field.name] = field -        ret = SortedDict() -        nested = bool(self.opts.depth) -        is_pk = True  # First field in the list is the pk - -        for model_field in fields: -            if is_pk: -                field = self.get_pk_field(model_field) -                is_pk = False -            elif model_field.rel and nested: -                field = self.get_nested_field(model_field) -            elif model_field.rel: +        # Deal with forward relationships +        forward_rels = [field for field in opts.fields if field.serialize] +        forward_rels += [field for field in opts.many_to_many if field.serialize] + +        for model_field in forward_rels: +            if model_field.rel:                  to_many = isinstance(model_field,                                       models.fields.related.ManyToManyField) -                field = self.get_related_field(model_field, to_many=to_many) +                related_model = model_field.rel.to + +            if model_field.rel and nested: +                if len(inspect.getargspec(self.get_nested_field).args) == 2: +                    warnings.warn( +                        'The `get_nested_field(model_field)` call signature ' +                        'is due to be deprecated. ' +                        'Use `get_nested_field(model_field, related_model, ' +                        'to_many) instead', +                        PendingDeprecationWarning +                    ) +                    field = self.get_nested_field(model_field) +                else: +                    field = self.get_nested_field(model_field, related_model, to_many) +            elif model_field.rel: +                if len(inspect.getargspec(self.get_nested_field).args) == 3: +                    warnings.warn( +                        'The `get_related_field(model_field, to_many)` call ' +                        'signature is due to be deprecated. ' +                        'Use `get_related_field(model_field, related_model, ' +                        'to_many) instead', +                        PendingDeprecationWarning +                    ) +                    field = self.get_related_field(model_field, to_many=to_many) +                else: +                    field = self.get_related_field(model_field, related_model, to_many)              else:                  field = self.get_field(model_field)              if field:                  ret[model_field.name] = field +        # Deal with reverse relationships +        if not self.opts.fields: +            reverse_rels = [] +        else: +            # Reverse relationships are only included if they are explicitly +            # present in the `fields` option on the serializer +            reverse_rels = opts.get_all_related_objects() +            reverse_rels += opts.get_all_related_many_to_many_objects() + +        for relation in reverse_rels: +            accessor_name = relation.get_accessor_name() +            if not self.opts.fields or accessor_name not in self.opts.fields: +                continue +            related_model = relation.model +            to_many = relation.field.rel.multiple + +            if nested: +                field = self.get_nested_field(None, related_model, to_many) +            else: +                field = self.get_related_field(None, related_model, to_many) + +            if field: +                ret[accessor_name] = field + +        # Add the `read_only` flag to any fields that have bee specified +        # in the `read_only_fields` option          for field_name in self.opts.read_only_fields:              assert field_name in ret, \                  "read_only_fields on '%s' included invalid item '%s'" % \ @@ -600,29 +662,36 @@ class ModelSerializer(Serializer):          """          return self.get_field(model_field) -    def get_nested_field(self, model_field): +    def get_nested_field(self, model_field, related_model, to_many):          """          Creates a default instance of a nested relational field. + +        Note that model_field will be `None` for reverse relationships.          """          class NestedModelSerializer(ModelSerializer):              class Meta: -                model = model_field.rel.to +                model = related_model                  depth = self.opts.depth - 1 -        return NestedModelSerializer() +        return NestedModelSerializer(many=to_many) -    def get_related_field(self, model_field, to_many=False): +    def get_related_field(self, model_field, related_model, to_many):          """          Creates a default instance of a flat relational field. + +        Note that model_field will be `None` for reverse relationships.          """          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to) +          kwargs = { -            'required': not(model_field.null or model_field.blank), -            'queryset': model_field.rel.to._default_manager, +            'queryset': related_model._default_manager,              'many': to_many          } +        if model_field: +            kwargs['required'] = not(model_field.null or model_field.blank) +          return PrimaryKeyRelatedField(**kwargs)      def get_field(self, model_field): @@ -758,6 +827,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):      def __init__(self, meta):          super(HyperlinkedModelSerializerOptions, self).__init__(meta)          self.view_name = getattr(meta, 'view_name', None) +        self.lookup_field = getattr(meta, 'slug_field', None)  class HyperlinkedModelSerializer(ModelSerializer): @@ -767,6 +837,7 @@ class HyperlinkedModelSerializer(ModelSerializer):      """      _options_class = HyperlinkedModelSerializerOptions      _default_view_name = '%(model_name)s-detail' +    _hyperlink_field_class = HyperlinkedRelatedField      url = HyperlinkedIdentityField() @@ -787,22 +858,28 @@ class HyperlinkedModelSerializer(ModelSerializer):          return self._default_view_name % format_kwargs      def get_pk_field(self, model_field): -        return None +        if self.opts.fields and model_field.name in self.opts.fields: +            return self.get_field(model_field) -    def get_related_field(self, model_field, to_many): +    def get_related_field(self, model_field, related_model, to_many):          """          Creates a default instance of a flat relational field.          """          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to) -        rel = model_field.rel.to          kwargs = { -            'required': not(model_field.null or model_field.blank), -            'queryset': rel._default_manager, -            'view_name': self._get_default_view_name(rel), +            'queryset': related_model._default_manager, +            'view_name': self._get_default_view_name(related_model),              'many': to_many          } -        return HyperlinkedRelatedField(**kwargs) + +        if model_field: +            kwargs['required'] = not(model_field.null or model_field.blank) + +        if self.opts.lookup_field: +            kwargs['lookup_field'] = self.opts.lookup_field + +        return self._hyperlink_field_class(**kwargs)      def get_identity(self, data):          """ diff --git a/rest_framework/settings.py b/rest_framework/settings.py index eede0c5a..734d8478 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -29,6 +29,7 @@ from rest_framework.compat import six  USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)  DEFAULTS = { +    # Base API policies      'DEFAULT_RENDERER_CLASSES': (          'rest_framework.renderers.JSONRenderer',          'rest_framework.renderers.BrowsableAPIRenderer', @@ -50,11 +51,15 @@ DEFAULTS = {      'DEFAULT_CONTENT_NEGOTIATION_CLASS':          'rest_framework.negotiation.DefaultContentNegotiation', + +    # Genric view behavior      'DEFAULT_MODEL_SERIALIZER_CLASS':          'rest_framework.serializers.ModelSerializer',      'DEFAULT_PAGINATION_SERIALIZER_CLASS':          'rest_framework.pagination.PaginationSerializer', +    'DEFAULT_FILTER_BACKENDS': (), +    # Throttling      'DEFAULT_THROTTLE_RATES': {          'user': None,          'anon': None, @@ -64,9 +69,6 @@ DEFAULTS = {      'PAGINATE_BY': None,      'PAGINATE_BY_PARAM': None, -    # Filtering -    'FILTER_BACKEND': None, -      # Authentication      'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',      'UNAUTHENTICATED_TOKEN': None, @@ -95,6 +97,9 @@ DEFAULTS = {          ISO_8601,      ),      'TIME_FORMAT': ISO_8601, + +    # Pending deprecation +    'FILTER_BACKEND': None,  } @@ -108,6 +113,7 @@ IMPORT_STRINGS = (      'DEFAULT_CONTENT_NEGOTIATION_CLASS',      'DEFAULT_MODEL_SERIALIZER_CLASS',      'DEFAULT_PAGINATION_SERIALIZER_CLASS', +    'DEFAULT_FILTER_BACKENDS',      'FILTER_BACKEND',      'UNAUTHENTICATED_USER',      'UNAUTHENTICATED_TOKEN', diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py index 5b3315bc..52c1a34c 100644 --- a/rest_framework/tests/description.py +++ b/rest_framework/tests/description.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals  from django.test import TestCase  from rest_framework.views import APIView  from rest_framework.compat import apply_markdown +from rest_framework.utils.formatting import get_view_name, get_view_description  # We check that docstrings get nicely un-indented.  DESCRIPTION = """an example docstring @@ -49,22 +50,16 @@ MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>  class TestViewNamesAndDescriptions(TestCase): -    def test_resource_name_uses_classname_by_default(self): -        """Ensure Resource names are based on the classname by default.""" +    def test_view_name_uses_class_name(self): +        """ +        Ensure view names are based on the class name. +        """          class MockView(APIView):              pass -        self.assertEqual(MockView().get_name(), 'Mock') +        self.assertEqual(get_view_name(MockView), 'Mock') -    def test_resource_name_can_be_set_explicitly(self): -        """Ensure Resource names can be set using the 'get_name' method.""" -        example = 'Some Other Name' -        class MockView(APIView): -            def get_name(self): -                return example -        self.assertEqual(MockView().get_name(), example) - -    def test_resource_description_uses_docstring_by_default(self): -        """Ensure Resource names are based on the docstring by default.""" +    def test_view_description_uses_docstring(self): +        """Ensure view descriptions are based on the docstring."""          class MockView(APIView):              """an example docstring              ==================== @@ -81,44 +76,32 @@ class TestViewNamesAndDescriptions(TestCase):              # hash style header #""" -        self.assertEqual(MockView().get_description(), DESCRIPTION) - -    def test_resource_description_can_be_set_explicitly(self): -        """Ensure Resource descriptions can be set using the 'get_description' method.""" -        example = 'Some other description' - -        class MockView(APIView): -            """docstring""" -            def get_description(self): -                return example -        self.assertEqual(MockView().get_description(), example) +        self.assertEqual(get_view_description(MockView), DESCRIPTION) -    def test_resource_description_supports_unicode(self): +    def test_view_description_supports_unicode(self): +        """ +        Unicode in docstrings should be respected. +        """          class MockView(APIView):              """Проверка"""              pass -        self.assertEqual(MockView().get_description(), "Проверка") - - -    def test_resource_description_does_not_require_docstring(self): -        """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method.""" -        example = 'Some other description' - -        class MockView(APIView): -            def get_description(self): -                return example -        self.assertEqual(MockView().get_description(), example) +        self.assertEqual(get_view_description(MockView), "Проверка") -    def test_resource_description_can_be_empty(self): -        """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string.""" +    def test_view_description_can_be_empty(self): +        """ +        Ensure that if a view has no docstring, +        then it's description is the empty string. +        """          class MockView(APIView):              pass -        self.assertEqual(MockView().get_description(), '') +        self.assertEqual(get_view_description(MockView), '')      def test_markdown(self): -        """Ensure markdown to HTML works as expected""" +        """ +        Ensure markdown to HTML works as expected. +        """          if apply_markdown:              gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21              lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 19c663d8..3cdfa0f6 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -3,12 +3,14 @@ General serializer field tests.  """  from __future__ import unicode_literals  import datetime +from decimal import Decimal  from django.db import models  from django.test import TestCase  from django.core import validators  from rest_framework import serializers +from rest_framework.serializers import Serializer  class TimestampedModel(models.Model): @@ -481,3 +483,166 @@ class TimeFieldTest(TestCase):          self.assertEqual('04 - 00 [000000]', result_1)          self.assertEqual('04 - 59 [000000]', result_2)          self.assertEqual('04 - 59 [000200]', result_3) + + +class DecimalFieldTest(TestCase): +    """ +    Tests for the DecimalField from_native() and to_native() behavior +    """ + +    def test_from_native_string(self): +        """ +        Make sure from_native() accepts string values +        """ +        f = serializers.DecimalField() +        result_1 = f.from_native('9000') +        result_2 = f.from_native('1.00000001') + +        self.assertEqual(Decimal('9000'), result_1) +        self.assertEqual(Decimal('1.00000001'), result_2) + +    def test_from_native_invalid_string(self): +        """ +        Make sure from_native() raises ValidationError on passing invalid string +        """ +        f = serializers.DecimalField() + +        try: +            f.from_native('123.45.6') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Enter a number."]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_integer(self): +        """ +        Make sure from_native() accepts integer values +        """ +        f = serializers.DecimalField() +        result = f.from_native(9000) + +        self.assertEqual(Decimal('9000'), result) + +    def test_from_native_float(self): +        """ +        Make sure from_native() accepts float values +        """ +        f = serializers.DecimalField() +        result = f.from_native(1.00000001) + +        self.assertEqual(Decimal('1.00000001'), result) + +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.DecimalField() +        result = f.from_native('') + +        self.assertEqual(result, None) + +    def test_from_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DecimalField() +        result = f.from_native(None) + +        self.assertEqual(result, None) + +    def test_to_native(self): +        """ +        Make sure to_native() returns Decimal as string. +        """ +        f = serializers.DecimalField() + +        result_1 = f.to_native(Decimal('9000')) +        result_2 = f.to_native(Decimal('1.00000001')) + +        self.assertEqual(Decimal('9000'), result_1) +        self.assertEqual(Decimal('1.00000001'), result_2) + +    def test_to_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DecimalField(required=False) +        self.assertEqual(None, f.to_native(None)) + +    def test_valid_serialization(self): +        """ +        Make sure the serializer works correctly +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_value=9010, +                                                     min_value=9000, +                                                     max_digits=6, +                                                     decimal_places=2) + +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) + +        self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) +        self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) +        self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) + +    def test_raise_max_value(self): +        """ +        Make sure max_value violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_value=100) + +        s = DecimalSerializer(data={'decimal_field': '123'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure this value is less than or equal to 100.']}) + +    def test_raise_min_value(self): +        """ +        Make sure min_value violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(min_value=100) + +        s = DecimalSerializer(data={'decimal_field': '99'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) + +    def test_raise_max_digits(self): +        """ +        Make sure max_digits violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_digits=5) + +        s = DecimalSerializer(data={'decimal_field': '123.456'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) + +    def test_raise_max_decimal_places(self): +        """ +        Make sure max_decimal_places violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(decimal_places=3) + +        s = DecimalSerializer(data={'decimal_field': '123.4567'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) + +    def test_raise_max_whole_digits(self): +        """ +        Make sure max_whole_digits violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) + +        s = DecimalSerializer(data={'decimal_field': '12345.6'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
\ No newline at end of file diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 1a71558c..1e53a5cd 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -61,7 +61,7 @@ if django_filters:  class CommonFilteringTestCase(TestCase):      def _serialize_object(self, obj):          return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} -     +      def setUp(self):          """          Create 10 FilterableItem instances. @@ -190,7 +190,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):      Integration tests for filtered detail views.      """      urls = 'rest_framework.tests.filterset' -     +      def _get_url(self, item):          return reverse('detail-view', kwargs=dict(pk=item.pk)) @@ -221,7 +221,7 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):          response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertEqual(response.data, low_item_data) -         +          # Tests that multiple filters works.          search_decimal = Decimal('5.25')          search_date = datetime.date(2012, 10, 2) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index b5702a48..b1eed9a7 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -26,42 +26,44 @@ urlpatterns = patterns('',  ) +# ManyToMany  class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): -    sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail') -      class Meta:          model = ManyToManyTarget +        fields = ('url', 'name', 'sources')  class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):      class Meta:          model = ManyToManySource +        fields = ('url', 'name', 'targets') +# ForeignKey  class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): -    sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail') -      class Meta:          model = ForeignKeyTarget +        fields = ('url', 'name', 'sources')  class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):      class Meta:          model = ForeignKeySource +        fields = ('url', 'name', 'target')  # Nullable ForeignKey  class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):      class Meta:          model = NullableForeignKeySource +        fields = ('url', 'name', 'target') -# OneToOne +# Nullable OneToOne  class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): -    nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') -      class Meta:          model = OneToOneTarget +        fields = ('url', 'name', 'nullable_source')  # TODO: Add test that .data cannot be accessed prior to .is_valid diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index a125ba65..f6d006b3 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null  class ForeignKeySourceSerializer(serializers.ModelSerializer):      class Meta: -        depth = 1 -        model = ForeignKeySource - - -class FlatForeignKeySourceSerializer(serializers.ModelSerializer): -    class Meta:          model = ForeignKeySource +        fields = ('id', 'name', 'target') +        depth = 1  class ForeignKeyTargetSerializer(serializers.ModelSerializer): -    sources = FlatForeignKeySourceSerializer(many=True) -      class Meta:          model = ForeignKeyTarget +        fields = ('id', 'name', 'sources') +        depth = 1  class NullableForeignKeySourceSerializer(serializers.ModelSerializer):      class Meta: -        depth = 1          model = NullableForeignKeySource - - -class NullableOneToOneSourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = NullableOneToOneSource +        fields = ('id', 'name', 'target') +        depth = 1  class NullableOneToOneTargetSerializer(serializers.ModelSerializer): -    nullable_source = NullableOneToOneSourceSerializer() -      class Meta:          model = OneToOneTarget +        fields = ('id', 'name', 'nullable_source') +        depth = 1  class ReverseForeignKeyTests(TestCase): diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index f08e1808..5ce8b567 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore  from rest_framework.compat import six +# ManyToMany  class ManyToManyTargetSerializer(serializers.ModelSerializer): -    sources = serializers.PrimaryKeyRelatedField(many=True) -      class Meta:          model = ManyToManyTarget +        fields = ('id', 'name', 'sources')  class ManyToManySourceSerializer(serializers.ModelSerializer):      class Meta:          model = ManyToManySource +        fields = ('id', 'name', 'targets') +# ForeignKey  class ForeignKeyTargetSerializer(serializers.ModelSerializer): -    sources = serializers.PrimaryKeyRelatedField(many=True) -      class Meta:          model = ForeignKeyTarget +        fields = ('id', 'name', 'sources')  class ForeignKeySourceSerializer(serializers.ModelSerializer):      class Meta:          model = ForeignKeySource +        fields = ('id', 'name', 'target') +# Nullable ForeignKey  class NullableForeignKeySourceSerializer(serializers.ModelSerializer):      class Meta:          model = NullableForeignKeySource +        fields = ('id', 'name', 'target') -# OneToOne +# Nullable OneToOne  class NullableOneToOneTargetSerializer(serializers.ModelSerializer): -    nullable_source = serializers.PrimaryKeyRelatedField() -      class Meta:          model = OneToOneTarget +        fields = ('id', 'name', 'nullable_source')  # TODO: Add test that .data cannot be accessed prior to .is_valid diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index bd874253..84e1ee4e 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -357,7 +357,6 @@ class CustomValidationTests(TestCase):          def validate_email(self, attrs, source):              value = attrs[source] -              return attrs          def validate_content(self, attrs, source): @@ -738,6 +737,43 @@ class ManyRelatedTests(TestCase):          self.assertEqual(serializer.data, expected) +    def test_include_reverse_relations(self): +        post = BlogPost.objects.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlogPost +                fields = ('id', 'title', 'blogpostcomment_set') + +        serializer = BlogPostSerializer(instance=post) +        expected = { +            'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] +        } +        self.assertEqual(serializer.data, expected) + +    def test_depth_include_reverse_relations(self): +        post = BlogPost.objects.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlogPost +                fields = ('id', 'title', 'blogpostcomment_set') +                depth = 1 + +        serializer = BlogPostSerializer(instance=post) +        expected = { +            'id': 1, 'title': 'Test blog post', +            'blogpostcomment_set': [ +                {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, +                {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} +            ] +        } +        self.assertEqual(serializer.data, expected) +      def test_callable_source(self):          post = BlogPost.objects.create(title="Test blog post")          post.blogpostcomment_set.create(text="I love this blog post") @@ -1073,7 +1109,7 @@ class DeserializeListTestCase(TestCase):      def test_no_errors(self):          data = [self.data.copy() for x in range(0, 3)] -        serializer = CommentSerializer(data=data) +        serializer = CommentSerializer(data=data, many=True)          self.assertTrue(serializer.is_valid())          self.assertTrue(isinstance(serializer.object, list))          self.assertTrue( @@ -1085,7 +1121,7 @@ class DeserializeListTestCase(TestCase):          invalid_item['email'] = ''          data = [self.data.copy(), invalid_item, self.data.copy()] -        serializer = CommentSerializer(data=data) +        serializer = CommentSerializer(data=data, many=True)          self.assertFalse(serializer.is_valid())          expected = [{}, {'email': ['This field is required.']}, {}]          self.assertEqual(serializer.errors, expected) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py index 6a29c652..71d0e24b 100644 --- a/rest_framework/tests/serializer_nested.py +++ b/rest_framework/tests/serializer_nested.py @@ -109,7 +109,7 @@ class WritableNestedSerializerBasicTests(TestCase):              }          ] -        serializer = self.AlbumSerializer(data=data) +        serializer = self.AlbumSerializer(data=data, many=True)          self.assertEqual(serializer.is_valid(), False)          self.assertEqual(serializer.errors, expected_errors) @@ -241,6 +241,6 @@ class WritableNestedSerializerObjectTests(TestCase):              )          ] -        serializer = self.AlbumSerializer(data=data) +        serializer = self.AlbumSerializer(data=data, many=True)          self.assertEqual(serializer.is_valid(), True)          self.assertEqual(serializer.object, expected_object) diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 810cad63..93ea9816 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -1,3 +1,6 @@ +""" +Provides various throttling policies. +"""  from __future__ import unicode_literals  from django.core.cache import cache  from rest_framework import exceptions @@ -28,9 +31,8 @@ class SimpleRateThrottle(BaseThrottle):      A simple cache implementation, that only requires `.get_cache_key()`      to be overridden. -    The rate (requests / seconds) is set by a :attr:`throttle` attribute -    on the :class:`.View` class.  The attribute is a string of the form 'number of -    requests/period'. +    The rate (requests / seconds) is set by a `throttle` attribute on the View +    class.  The attribute is a string of the form 'number_of_requests/period'.      Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index af21ac79..28801d09 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,26 +1,36 @@  from __future__ import unicode_literals  from django.core.urlresolvers import resolve, get_script_prefix +from rest_framework.utils.formatting import get_view_name  def get_breadcrumbs(url): -    """Given a url returns a list of breadcrumbs, which are each a tuple of (name, url).""" +    """ +    Given a url returns a list of breadcrumbs, which are each a +    tuple of (name, url). +    """      from rest_framework.views import APIView      def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): -        """Add tuples of (name, url) to the breadcrumbs list, progressively chomping off parts of the url.""" +        """ +        Add tuples of (name, url) to the breadcrumbs list, +        progressively chomping off parts of the url. +        """          try:              (view, unused_args, unused_kwargs) = resolve(url)          except Exception:              pass          else: -            # Check if this is a REST framework view, and if so add it to the breadcrumbs -            if isinstance(getattr(view, 'cls_instance', None), APIView): +            # Check if this is a REST framework view, +            # and if so add it to the breadcrumbs +            if issubclass(getattr(view, 'cls', None), APIView):                  # Don't list the same view twice in a row.                  # Probably an optional trailing slash.                  if not seen or seen[-1] != view: -                    breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) +                    suffix = getattr(view, 'suffix', None) +                    name = get_view_name(view.cls, suffix) +                    breadcrumbs_list.insert(0, (name, prefix + url))                      seen.append(view)          if url == '': @@ -28,11 +38,15 @@ def get_breadcrumbs(url):              return breadcrumbs_list          elif url.endswith('/'): -            # Drop trailing slash off the end and continue to try to resolve more breadcrumbs -            return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) - -        # Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs -        return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) +            # Drop trailing slash off the end and continue to try to +            # resolve more breadcrumbs +            url = url.rstrip('/') +            return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) + +        # Drop trailing non-slash off the end and continue to try to +        # resolve more breadcrumbs +        url = url[:url.rfind('/') + 1] +        return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)      prefix = get_script_prefix().rstrip('/')      url = url[len(prefix):] diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py new file mode 100644 index 00000000..ebadb3a6 --- /dev/null +++ b/rest_framework/utils/formatting.py @@ -0,0 +1,80 @@ +""" +Utility functions to return a formatted name and description for a given view. +""" +from __future__ import unicode_literals + +from django.utils.html import escape +from django.utils.safestring import mark_safe +from rest_framework.compat import apply_markdown +import re + + +def _remove_trailing_string(content, trailing): +    """ +    Strip trailing component `trailing` from `content` if it exists. +    Used when generating names from view classes. +    """ +    if content.endswith(trailing) and content != trailing: +        return content[:-len(trailing)] +    return content + + +def _remove_leading_indent(content): +    """ +    Remove leading indent from a block of text. +    Used when generating descriptions from docstrings. +    """ +    whitespace_counts = [len(line) - len(line.lstrip(' ')) +                         for line in content.splitlines()[1:] if line.lstrip()] + +    # unindent the content if needed +    if whitespace_counts: +        whitespace_pattern = '^' + (' ' * min(whitespace_counts)) +        content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) +    content = content.strip('\n') +    return content + + +def _camelcase_to_spaces(content): +    """ +    Translate 'CamelCaseNames' to 'Camel Case Names'. +    Used when generating names from view classes. +    """ +    camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' +    content = re.sub(camelcase_boundry, ' \\1', content).strip() +    return ' '.join(content.split('_')).title() + + +def get_view_name(cls, suffix=None): +    """ +    Return a formatted name for an `APIView` class or `@api_view` function. +    """ +    name = cls.__name__ +    name = _remove_trailing_string(name, 'View') +    name = _remove_trailing_string(name, 'ViewSet') +    name = _camelcase_to_spaces(name) +    if suffix: +        name += ' ' + suffix +    return name + + +def get_view_description(cls, html=False): +    """ +    Return a description for an `APIView` class or `@api_view` function. +    """ +    description = cls.__doc__ or '' +    description = _remove_leading_indent(description) +    if html: +        return markup_description(description) +    return description + + +def markup_description(description): +    """ +    Apply HTML markup to the given description. +    """ +    if apply_markdown: +        description = apply_markdown(description) +    else: +        description = escape(description).replace('\n', '<br />') +    return mark_safe(description) diff --git a/rest_framework/views.py b/rest_framework/views.py index 7c97607b..555fa2f4 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,54 +1,16 @@  """ -Provides an APIView class that is used as the base of all class-based views. +Provides an APIView class that is the base of all views in REST framework.  """  from __future__ import unicode_literals  from django.core.exceptions import PermissionDenied  from django.http import Http404, HttpResponse -from django.utils.html import escape -from django.utils.safestring import mark_safe  from django.views.decorators.csrf import csrf_exempt  from rest_framework import status, exceptions -from rest_framework.compat import View, apply_markdown +from rest_framework.compat import View  from rest_framework.response import Response  from rest_framework.request import Request  from rest_framework.settings import api_settings -import re - - -def _remove_trailing_string(content, trailing): -    """ -    Strip trailing component `trailing` from `content` if it exists. -    Used when generating names from view classes. -    """ -    if content.endswith(trailing) and content != trailing: -        return content[:-len(trailing)] -    return content - - -def _remove_leading_indent(content): -    """ -    Remove leading indent from a block of text. -    Used when generating descriptions from docstrings. -    """ -    whitespace_counts = [len(line) - len(line.lstrip(' ')) -                         for line in content.splitlines()[1:] if line.lstrip()] - -    # unindent the content if needed -    if whitespace_counts: -        whitespace_pattern = '^' + (' ' * min(whitespace_counts)) -        content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) -    content = content.strip('\n') -    return content - - -def _camelcase_to_spaces(content): -    """ -    Translate 'CamelCaseNames' to 'Camel Case Names'. -    Used when generating names from view classes. -    """ -    camelcase_boundry = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' -    content = re.sub(camelcase_boundry, ' \\1', content).strip() -    return ' '.join(content.split('_')).title() +from rest_framework.utils.formatting import get_view_name, get_view_description  class APIView(View): @@ -64,22 +26,21 @@ class APIView(View):      @classmethod      def as_view(cls, **initkwargs):          """ -        Override the default :meth:`as_view` to store an instance of the view -        as an attribute on the callable function.  This allows us to discover -        information about the view when we do URL reverse lookups. +        Store the original class on the view function. + +        This allows us to discover information about the view when we do URL +        reverse lookups.  Used for breadcrumb generation.          """ -        # TODO: deprecate?          view = super(APIView, cls).as_view(**initkwargs) -        view.cls_instance = cls(**initkwargs) +        view.cls = cls          return view      @property      def allowed_methods(self):          """ -        Return the list of allowed HTTP methods, uppercased. +        Wrap Django's private `_allowed_methods` interface in a public property.          """ -        return [method.upper() for method in self.http_method_names -                if hasattr(self, method)] +        return self._allowed_methods()      @property      def default_response_headers(self): @@ -90,43 +51,10 @@ class APIView(View):              'Vary': 'Accept'          } -    def get_name(self): -        """ -        Return the resource or view class name for use as this view's name. -        Override to customize. -        """ -        # TODO: deprecate? -        name = self.__class__.__name__ -        name = _remove_trailing_string(name, 'View') -        return _camelcase_to_spaces(name) - -    def get_description(self, html=False): -        """ -        Return the resource or view docstring for use as this view's description. -        Override to customize. -        """ -        # TODO: deprecate? -        description = self.__doc__ or '' -        description = _remove_leading_indent(description) -        if html: -            return self.markup_description(description) -        return description - -    def markup_description(self, description): -        """ -        Apply HTML markup to the description of this view. -        """ -        # TODO: deprecate? -        if apply_markdown: -            description = apply_markdown(description) -        else: -            description = escape(description).replace('\n', '<br />') -        return mark_safe(description) -      def metadata(self, request):          return { -            'name': self.get_name(), -            'description': self.get_description(), +            'name': get_view_name(self.__class__), +            'description': get_view_description(self.__class__),              'renders': [renderer.media_type for renderer in self.renderer_classes],              'parses': [parser.media_type for parser in self.parser_classes],          } @@ -140,7 +68,8 @@ class APIView(View):      def http_method_not_allowed(self, request, *args, **kwargs):          """ -        Called if `request.method` does not correspond to a handler method. +        If `request.method` does not correspond to a handler method, +        determine what kind of exception to raise.          """          raise exceptions.MethodNotAllowed(request.method) diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py new file mode 100644 index 00000000..0eb3e86d --- /dev/null +++ b/rest_framework/viewsets.py @@ -0,0 +1,132 @@ +""" +ViewSets are essentially just a type of class based view, that doesn't provide +any method handlers, such as `get()`, `post()`, etc... but instead has actions, +such as `list()`, `retrieve()`, `create()`, etc... + +Actions are only bound to methods at the point of instantiating the views. + +    user_list = UserViewSet.as_view({'get': 'list'}) +    user_detail = UserViewSet.as_view({'get': 'retrieve'}) + +Typically, rather than instantiate views from viewsets directly, you'll +regsiter the viewset with a router and let the URL conf be determined +automatically. + +    router = DefaultRouter() +    router.register(r'users', UserViewSet, 'user') +    urlpatterns = router.urls +""" +from __future__ import unicode_literals + +from functools import update_wrapper +from django.utils.decorators import classonlymethod +from rest_framework import views, generics, mixins + + +class ViewSetMixin(object): +    """ +    This is the magic. + +    Overrides `.as_view()` so that it takes an `actions` keyword that performs +    the binding of HTTP methods to actions on the Resource. + +    For example, to create a concrete view binding the 'GET' and 'POST' methods +    to the 'list' and 'create' actions... + +    view = MyViewSet.as_view({'get': 'list', 'post': 'create'}) +    """ + +    @classonlymethod +    def as_view(cls, actions=None, **initkwargs): +        """ +        Because of the way class based views create a closure around the +        instantiated view, we need to totally reimplement `.as_view`, +        and slightly modify the view function that is created and returned. +        """ +        # The suffix initkwarg is reserved for identifing the viewset type +        # eg. 'List' or 'Instance'. +        cls.suffix = None + +        # sanitize keyword arguments +        for key in initkwargs: +            if key in cls.http_method_names: +                raise TypeError("You tried to pass in the %s method name as a " +                                "keyword argument to %s(). Don't do that." +                                % (key, cls.__name__)) +            if not hasattr(cls, key): +                raise TypeError("%s() received an invalid keyword %r" % ( +                    cls.__name__, key)) + +        def view(request, *args, **kwargs): +            self = cls(**initkwargs) +            # We also store the mapping of request methods to actions, +            # so that we can later set the action attribute. +            # eg. `self.action = 'list'` on an incoming GET request. +            self.action_map = actions + +            # Bind methods to actions +            # This is the bit that's different to a standard view +            for method, action in actions.items(): +                handler = getattr(self, action) +                setattr(self, method, handler) + +            # Patch this in as it's otherwise only present from 1.5 onwards +            if hasattr(self, 'get') and not hasattr(self, 'head'): +                self.head = self.get + +            # And continue as usual +            return self.dispatch(request, *args, **kwargs) + +        # take name and docstring from class +        update_wrapper(view, cls, updated=()) + +        # and possible attributes set by decorators +        # like csrf_exempt from dispatch +        update_wrapper(view, cls.dispatch, assigned=()) + +        # We need to set these on the view function, so that breadcrumb +        # generation can pick out these bits of information from a +        # resolved URL. +        view.cls = cls +        view.suffix = initkwargs.get('suffix', None) +        return view + +    def initialize_request(self, request, *args, **kargs): +        """ +        Set the `.action` attribute on the view, +        depending on the request method. +        """ +        request = super(ViewSetMixin, self).initialize_request(request, *args, **kargs) +        self.action = self.action_map.get(request.method.lower()) +        return request + + +class ViewSet(ViewSetMixin, views.APIView): +    """ +    The base ViewSet class does not provide any actions by default. +    """ +    pass + + +class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, +                           mixins.ListModelMixin, +                           ViewSetMixin, +                           generics.GenericAPIView): +    """ +    A viewset that provides default `list()` and `retrieve()` actions. +    """ +    pass + + +class ModelViewSet(mixins.CreateModelMixin, +                    mixins.RetrieveModelMixin, +                    mixins.UpdateModelMixin, +                    mixins.DestroyModelMixin, +                    mixins.ListModelMixin, +                    ViewSetMixin, +                    generics.GenericAPIView): +    """ +    A viewset that provides default `create()`, `retrieve()`, `update()`, +    `partial_update()`, `destroy()` and `list()` actions. +    """ +    pass  | 
