diff options
Diffstat (limited to 'rest_framework')
32 files changed, 899 insertions, 162 deletions
| diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 2bd2991b..f5483b9d 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,6 +1,20 @@ -__version__ = '2.3.8' +""" +______ _____ _____ _____    __                                             _     +| ___ \  ___/  ___|_   _|  / _|                                           | |    +| |_/ / |__ \ `--.  | |   | |_ _ __ __ _ _ __ ___   _____      _____  _ __| | __ +|    /|  __| `--. \ | |   |  _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ / +| |\ \| |___/\__/ / | |   | | | | | (_| | | | | | |  __/\ V  V / (_) | |  |   <  +\_| \_\____/\____/  \_/   |_| |_|  \__,_|_| |_| |_|\___| \_/\_/ \___/|_|  |_|\_| +""" -VERSION = __version__  # synonym +__title__ = 'Django REST framework' +__version__ = '2.3.10' +__author__ = 'Tom Christie' +__license__ = 'BSD 2-Clause' +__copyright__ = 'Copyright 2011-2013 Tom Christie' + +# Version synonym +VERSION = __version__  # Header encoding (see RFC5987)  HTTP_HEADER_ENCODING = 'iso-8859-1' diff --git a/rest_framework/compat.py b/rest_framework/compat.py index efd2581f..b4d37ab8 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -65,6 +65,13 @@ try:  except ImportError:      import urlparse +# UserDict moves in Python 3 +try: +    from UserDict import UserDict +    from UserDict import DictMixin +except ImportError: +    from collections import UserDict +    from collections import MutableMapping as DictMixin  # Try to import PIL in either of the two ways it can end up installed.  try: @@ -76,6 +83,22 @@ except ImportError:          Image = None +def get_model_name(model_cls): +    try: +        return model_cls._meta.model_name +    except AttributeError: +        # < 1.6 used module_name instead of model_name +        return model_cls._meta.module_name + + +def get_concrete_model(model_cls): +    try: +        return model_cls._meta.concrete_model +    except AttributeError: +        # 1.3 does not include concrete model +        return model_cls + +  # Django 1.5 add support for custom auth user model  if django.VERSION >= (1, 5):      AUTH_USER_MODEL = settings.AUTH_USER_MODEL diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f340510d..65edd0d6 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -125,6 +125,7 @@ class Field(object):      use_files = False      form_field_class = forms.CharField      type_label = 'field' +    widget = None      def __init__(self, source=None, label=None, help_text=None):          self.parent = None @@ -136,9 +137,29 @@ class Field(object):          if label is not None:              self.label = smart_text(label) +        else: +            self.label = None          if help_text is not None:              self.help_text = strip_multiple_choice_msg(smart_text(help_text)) +        else: +            self.help_text = None + +        self._errors = [] +        self._value = None +        self._name = None + +    @property +    def errors(self): +        return self._errors + +    def widget_html(self): +        if not self.widget: +            return '' +        return self.widget.render(self._name, self._value) + +    def label_tag(self): +        return '<label for="%s">%s:</label>' % (self._name, self.label)      def initialize(self, parent, field_name):          """ @@ -301,6 +322,7 @@ class WritableField(Field):              return          try: +            data = data or {}              if self.use_files:                  files = files or {}                  try: @@ -470,6 +492,7 @@ class ChoiceField(WritableField):      }      def __init__(self, choices=(), *args, **kwargs): +        self.empty = kwargs.pop('empty', '')          super(ChoiceField, self).__init__(*args, **kwargs)          self.choices = choices          if not self.required: @@ -486,6 +509,11 @@ class ChoiceField(WritableField):      choices = property(_get_choices, _set_choices) +    def metadata(self): +        data = super(ChoiceField, self).metadata() +        data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices] +        return data +      def validate(self, value):          """          Validates that the input is in self.choices. @@ -510,9 +538,10 @@ class ChoiceField(WritableField):          return False      def from_native(self, value): -        if value in validators.EMPTY_VALUES: -            return None -        return super(ChoiceField, self).from_native(value) +        value = super(ChoiceField, self).from_native(value) +        if value == self.empty or value in validators.EMPTY_VALUES: +            return self.empty +        return value  class EmailField(CharField): @@ -751,6 +780,7 @@ class IntegerField(WritableField):      type_name = 'IntegerField'      type_label = 'integer'      form_field_class = forms.IntegerField +    empty = 0      default_error_messages = {          'invalid': _('Enter a whole number.'), @@ -782,6 +812,7 @@ class FloatField(WritableField):      type_name = 'FloatField'      type_label = 'float'      form_field_class = forms.FloatField +    empty = 0      default_error_messages = {          'invalid': _("'%s' value must be a float."), @@ -802,6 +833,7 @@ class DecimalField(WritableField):      type_name = 'DecimalField'      type_label = 'decimal'      form_field_class = forms.DecimalField +    empty = Decimal('0')      default_error_messages = {          'invalid': _('Enter a number.'), @@ -934,7 +966,7 @@ class ImageField(FileField):              return None          from rest_framework.compat import Image -        assert Image is not None, 'PIL must be installed for ImageField support' +        assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.'          # We need to get a file object for PIL. We might have a path or we might          # have to read the data into memory. diff --git a/rest_framework/filters.py b/rest_framework/filters.py index b8fe7f77..5c6a187c 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -4,7 +4,7 @@ returned by list views.  """  from __future__ import unicode_literals  from django.db import models -from rest_framework.compat import django_filters, six, guardian +from rest_framework.compat import django_filters, six, guardian, get_model_name  from functools import reduce  import operator @@ -124,6 +124,7 @@ class OrderingFilter(BaseFilterBackend):      def remove_invalid_fields(self, queryset, ordering):          field_names = [field.name for field in queryset.model._meta.fields] +        field_names += queryset.query.aggregates.keys()          return [term for term in ordering if term.lstrip('-') in field_names]      def filter_queryset(self, request, queryset, view): @@ -158,7 +159,7 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend):          model_cls = queryset.model          kwargs = {              'app_label': model_cls._meta.app_label, -            'model_name': model_cls._meta.module_name +            'model_name': get_model_name(model_cls)          }          permission = self.perm_format % kwargs          return guardian.shortcuts.get_objects_for_user(user, permission, queryset) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 5fb37db7..bd33c01a 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -25,13 +25,13 @@ def strict_positive_int(integer_string, cutoff=None):          ret = min(ret, cutoff)      return ret -def get_object_or_404(queryset, **filter_kwargs): +def get_object_or_404(queryset, *filter_args, **filter_kwargs):      """      Same as Django's standard shortcut, but make sure to raise 404      if the filter_kwargs don't match the required types.      """      try: -        return _get_object_or_404(queryset, **filter_kwargs) +        return _get_object_or_404(queryset, *filter_args, **filter_kwargs)      except (TypeError, ValueError):          raise Http404 @@ -54,6 +54,7 @@ class GenericAPIView(views.APIView):      # If you want to use object lookups other than pk, set this attribute.      # For more complex lookup requirements override `get_object()`.      lookup_field = 'pk' +    lookup_url_kwarg = None      # Pagination settings      paginate_by = api_settings.PAGINATE_BY @@ -147,8 +148,8 @@ class GenericAPIView(views.APIView):          page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)          page = page_kwarg or page_query_param or 1          try: -            page_number = strict_positive_int(page) -        except ValueError: +            page_number = paginator.validate_number(page) +        except InvalidPage:              if page == 'last':                  page_number = paginator.num_pages              else: @@ -174,6 +175,14 @@ class GenericAPIView(views.APIView):          method if you want to apply the configured filtering backend to the          default queryset.          """ +        for backend in self.get_filter_backends(): +            queryset = backend().filter_queryset(self.request, queryset, self) +        return queryset + +    def get_filter_backends(self): +        """ +        Returns the list of filter backends that this view requires. +        """          filter_backends = self.filter_backends or []          if not filter_backends and self.filter_backend:              warnings.warn( @@ -184,10 +193,8 @@ class GenericAPIView(views.APIView):                  DeprecationWarning, stacklevel=2              )              filter_backends = [self.filter_backend] +        return filter_backends -        for backend in filter_backends: -            queryset = backend().filter_queryset(self.request, queryset, self) -        return queryset      ########################      ### The following methods provide default implementations @@ -278,9 +285,11 @@ class GenericAPIView(views.APIView):              pass  # Deprecation warning          # Perform the lookup filtering. +        # Note that `pk` and `slug` are deprecated styles of lookup filtering. +        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field +        lookup = self.kwargs.get(lookup_url_kwarg, None)          pk = self.kwargs.get(self.pk_url_kwarg, None)          slug = self.kwargs.get(self.slug_url_kwarg, None) -        lookup = self.kwargs.get(self.lookup_field, None)          if lookup is not None:              filter_kwargs = {self.lookup_field: lookup} @@ -335,6 +344,18 @@ class GenericAPIView(views.APIView):          """          pass +    def pre_delete(self, obj): +        """ +        Placeholder method for calling before deleting an object. +        """ +        pass + +    def post_delete(self, obj): +        """ +        Placeholder method for calling after saving an object. +        """ +        pass +      def metadata(self, request):          """          Return a dictionary of metadata about the view. diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 2c85d157..b62a4cc1 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -6,6 +6,7 @@ which allows mixin classes to be composed in interesting ways.  """  from __future__ import unicode_literals +from django.core.exceptions import ValidationError  from django.http import Http404  from rest_framework import status  from rest_framework.response import Response @@ -127,7 +128,12 @@ class UpdateModelMixin(object):                                           files=request.FILES, partial=partial)          if serializer.is_valid(): -            self.pre_save(serializer.object) +            try: +                self.pre_save(serializer.object) +            except ValidationError as err: +                # full_clean on model instance may be called in pre_save, so we +                # have to handle eventual errors. +                return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)              self.object = serializer.save(**save_kwargs)              self.post_save(self.object, created=created)              return Response(serializer.data, status=success_status_code) @@ -158,7 +164,8 @@ class UpdateModelMixin(object):          Set any attributes on the object that are implicit in the request.          """          # pk and/or slug attributes are implicit in the URL. -        lookup = self.kwargs.get(self.lookup_field, None) +        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field +        lookup = self.kwargs.get(lookup_url_kwarg, None)          pk = self.kwargs.get(self.pk_url_kwarg, None)          slug = self.kwargs.get(self.slug_url_kwarg, None)          slug_field = slug and self.slug_field or None @@ -185,5 +192,7 @@ class DestroyModelMixin(object):      """      def destroy(self, request, *args, **kwargs):          obj = self.get_object() +        self.pre_delete(obj)          obj.delete() +        self.post_delete(obj)          return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 98fc0341..f1b3e38d 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -83,7 +83,7 @@ class YAMLParser(BaseParser):              data = stream.read().decode(encoding)              return yaml.safe_load(data)          except (ValueError, yaml.parser.ParserError) as exc: -            raise ParseError('YAML parse error - %s' % six.u(exc)) +            raise ParseError('YAML parse error - %s' % six.text_type(exc))  class FormParser(BaseParser): @@ -153,7 +153,7 @@ class XMLParser(BaseParser):          try:              tree = etree.parse(stream, parser=parser, forbid_dtd=True)          except (etree.ParseError, ValueError) as exc: -            raise ParseError('XML parse error - %s' % six.u(exc)) +            raise ParseError('XML parse error - %s' % six.text_type(exc))          data = self._xml_convert(tree.getroot())          return data diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 14bec42c..d93dba19 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -3,7 +3,8 @@ Provides a set of pluggable permission policies.  """  from __future__ import unicode_literals  from django.http import Http404 -from rest_framework.compat import oauth2_provider_scope, oauth2_constants +from rest_framework.compat import (get_model_name, oauth2_provider_scope, +                                   oauth2_constants)  SAFE_METHODS = ['GET', 'HEAD', 'OPTIONS'] @@ -106,7 +107,7 @@ class DjangoModelPermissions(BasePermission):          """          kwargs = {              'app_label': model_cls._meta.app_label, -            'model_name': model_cls._meta.module_name +            'model_name': get_model_name(model_cls)          }          return [perm % kwargs for perm in self.perms_map[method]] @@ -167,7 +168,7 @@ class DjangoObjectPermissions(DjangoModelPermissions):      def get_required_object_permissions(self, method, model_cls):          kwargs = {              'app_label': model_cls._meta.app_label, -            'model_name': model_cls._meta.module_name +            'model_name': get_model_name(model_cls)          }          return [perm % kwargs for perm in self.perms_map[method]] diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 2ce51e97..2fdd3337 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -20,6 +20,7 @@ from rest_framework.compat import StringIO  from rest_framework.compat import six  from rest_framework.compat import smart_text  from rest_framework.compat import yaml +from rest_framework.exceptions import ParseError  from rest_framework.settings import api_settings  from rest_framework.request import is_form_media_type, override_method  from rest_framework.utils import encoders @@ -272,7 +273,9 @@ class TemplateHTMLRenderer(BaseRenderer):              return [self.template_name]          elif hasattr(view, 'get_template_names'):              return view.get_template_names() -        raise ImproperlyConfigured('Returned a template response with no template_name') +        elif hasattr(view, 'template_name'): +            return [view.template_name] +        raise ImproperlyConfigured('Returned a template response with no `template_name` attribute set on either the view or response')      def get_exception_template(self, response):          template_names = [name % {'status_code': response.status_code} @@ -334,71 +337,15 @@ class HTMLFormRenderer(BaseRenderer):      template = 'rest_framework/form.html'      charset = 'utf-8' -    def data_to_form_fields(self, data): -        fields = {} -        for key, val in data.fields.items(): -            if getattr(val, 'read_only', True): -                # Don't include read-only fields. -                continue - -            if getattr(val, 'fields', None): -                # Nested data not supported by HTML forms. -                continue - -            kwargs = {} -            kwargs['required'] = val.required - -            #if getattr(v, 'queryset', None): -            #    kwargs['queryset'] = v.queryset - -            if getattr(val, 'choices', None) is not None: -                kwargs['choices'] = val.choices - -            if getattr(val, 'regex', None) is not None: -                kwargs['regex'] = val.regex - -            if getattr(val, 'widget', None): -                widget = copy.deepcopy(val.widget) -                kwargs['widget'] = widget - -            if getattr(val, 'default', None) is not None: -                kwargs['initial'] = val.default - -            if getattr(val, 'label', None) is not None: -                kwargs['label'] = val.label - -            if getattr(val, 'help_text', None) is not None: -                kwargs['help_text'] = val.help_text - -            fields[key] = val.form_field_class(**kwargs) - -        return fields -      def render(self, data, accepted_media_type=None, renderer_context=None):          """          Render serializer data and return an HTML form, as a string.          """ -        # The HTMLFormRenderer currently uses something of a hack to render -        # the content, by translating each of the serializer fields into -        # an html form field, creating a dynamic form using those fields, -        # and then rendering that form. - -        # This isn't strictly neccessary, as we could render the serilizer -        # fields to HTML directly.  The implementation is historical and will -        # likely change at some point. - -        self.renderer_context = renderer_context or {} +        renderer_context = renderer_context or {}          request = renderer_context['request'] -        # Creating an on the fly form see: -        # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python -        fields = self.data_to_form_fields(data) -        DynamicForm = type(str('DynamicForm'), (forms.Form,), fields) -        data = None if data.empty else data -          template = loader.get_template(self.template) -        context = RequestContext(request, {'form': DynamicForm(data)}) - +        context = RequestContext(request, {'form': data})          return template.render(context) @@ -419,8 +366,13 @@ class BrowsableAPIRenderer(BaseRenderer):          """          renderers = [renderer for renderer in view.renderer_classes                       if not issubclass(renderer, BrowsableAPIRenderer)] +        non_template_renderers = [renderer for renderer in renderers +                                  if not hasattr(renderer, 'get_template_names')] +          if not renderers:              return None +        elif non_template_renderers: +            return non_template_renderers[0]()          return renderers[0]()      def get_content(self, renderer, data, @@ -468,6 +420,17 @@ class BrowsableAPIRenderer(BaseRenderer):          In the absence of the View having an associated form then return None.          """ +        if request.method == method: +            try: +                data = request.DATA +                files = request.FILES +            except ParseError: +                data = None +                files = None         +        else: +            data = None +            files = None +          with override_method(view, request, method) as request:              obj = getattr(view, 'object', None)              if not self.show_form_for_method(view, method, request, obj): @@ -480,9 +443,10 @@ class BrowsableAPIRenderer(BaseRenderer):                  or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)):                  return -            serializer = view.get_serializer(instance=obj) - +            serializer = view.get_serializer(instance=obj, data=data, files=files) +            serializer.is_valid()              data = serializer.data +              form_renderer = self.form_renderer_class()              return form_renderer.render(data, self.accepted_media_type, self.renderer_context) @@ -574,6 +538,7 @@ class BrowsableAPIRenderer(BaseRenderer):          renderer = self.get_default_renderer(view) +        raw_data_post_form = self.get_raw_data_form(view, 'POST', request)          raw_data_put_form = self.get_raw_data_form(view, 'PUT', request)          raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request)          raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form @@ -592,12 +557,11 @@ class BrowsableAPIRenderer(BaseRenderer):              'put_form': self.get_rendered_html_form(view, 'PUT', request),              'post_form': self.get_rendered_html_form(view, 'POST', request), -            'patch_form': self.get_rendered_html_form(view, 'PATCH', request),              'delete_form': self.get_rendered_html_form(view, 'DELETE', request),              'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),              'raw_data_put_form': raw_data_put_form, -            'raw_data_post_form': self.get_raw_data_form(view, 'POST', request), +            'raw_data_post_form': raw_data_post_form,              'raw_data_patch_form': raw_data_patch_form,              'raw_data_put_or_patch_form': raw_data_put_or_patch_form, diff --git a/rest_framework/request.py b/rest_framework/request.py index 977d4d96..fcea2508 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -334,7 +334,7 @@ class Request(object):              self._CONTENT_PARAM in self._data and              self._CONTENTTYPE_PARAM in self._data):              self._content_type = self._data[self._CONTENTTYPE_PARAM] -            self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(HTTP_HEADER_ENCODING)) +            self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))              self._data, self._files = (Empty, Empty)      def _parse(self): @@ -356,7 +356,16 @@ class Request(object):          if not parser:              raise exceptions.UnsupportedMediaType(media_type) -        parsed = parser.parse(stream, media_type, self.parser_context) +        try: +            parsed = parser.parse(stream, media_type, self.parser_context) +        except: +            # If we get an exception during parsing, fill in empty data and +            # re-raise.  Ensures we don't simply repeat the error when +            # attempting to render the browsable renderer response, or when +            # logging the request or similar. +            self._data = QueryDict('', self._request._encoding) +            self._files = MultiValueDict() +            raise          # Parser classes may return the raw data, or a          # DataAndFiles object.  Unpack the result as required. diff --git a/rest_framework/response.py b/rest_framework/response.py index 5877c8a3..1dc6abcf 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -61,6 +61,10 @@ class Response(SimpleTemplateResponse):              assert charset, 'renderer returned unicode, and did not specify ' \              'a charset value.'              return bytes(ret.encode(charset)) + +        if not ret: +            del self['Content-Type'] +          return ret      @property diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 9e3881a2..9c27717f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -6,8 +6,8 @@ form encoded input.  Serialization in REST framework is a two-phase process:  1. Serializers marshal between complex types like model instances, and -python primatives. -2. The process of marshalling between python primatives and request and +python primitives. +2. The process of marshalling between python primitives and request and  response content is handled by parsers and renderers.  """  from __future__ import unicode_literals @@ -31,9 +31,17 @@ from rest_framework.relations import *  from rest_framework.fields import * +def pretty_name(name): +    """Converts 'first_name' to 'First name'""" +    if not name: +        return '' +    return name.replace('_', ' ').capitalize() + +  class RelationsList(list):      _deleted = [] +  class NestedValidationError(ValidationError):      """      The default ValidationError behavior is to stringify each item in the list @@ -48,9 +56,13 @@ class NestedValidationError(ValidationError):      def __init__(self, message):          if isinstance(message, dict): -            self.messages = [message] +            self._messages = [message]          else: -            self.messages = message +            self._messages = message + +    @property +    def messages(self): +        return self._messages  class DictWithMetadata(dict): @@ -254,10 +266,13 @@ class BaseSerializer(WritableField):          for field_name, field in self.fields.items():              if field_name in self._errors:                  continue + +            source = field.source or field_name +            if self.partial and source not in attrs: +                continue              try:                  validate_method = getattr(self, 'validate_%s' % field_name, None)                  if validate_method: -                    source = field.source or field_name                      attrs = validate_method(attrs, source)              except ValidationError as err:                  self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) @@ -300,14 +315,19 @@ class BaseSerializer(WritableField):          """          ret = self._dict_class()          ret.fields = self._dict_class() -        ret.empty = obj is None          for field_name, field in self.fields.items(): +            if field.read_only and obj is None: +               continue              field.initialize(parent=self, field_name=field_name)              key = self.get_field_key(field_name)              value = field.field_to_native(obj, field_name) +            method = getattr(self, 'transform_%s' % field_name, None) +            if callable(method): +                value = method(obj, value)              ret[key] = value -            ret.fields[key] = field +            ret.fields[key] = self.augment_field(field, field_name, key, value) +          return ret      def from_native(self, data, files): @@ -315,6 +335,7 @@ class BaseSerializer(WritableField):          Deserialize primitives -> objects.          """          self._errors = {} +          if data is not None or files is not None:              attrs = self.restore_fields(data, files)              if attrs is not None: @@ -325,6 +346,15 @@ class BaseSerializer(WritableField):          if not self._errors:              return self.restore_object(attrs, instance=getattr(self, 'object', None)) +    def augment_field(self, field, field_name, key, value): +        # This horrible stuff is to manage serializers rendering to HTML +        field._errors = self._errors.get(key) if self._errors else None +        field._name = field_name +        field._value = self.init_data.get(key) if self._errors and self.init_data else value +        if not field.label: +            field.label = pretty_name(key) +        return field +      def field_to_native(self, obj, field_name):          """          Override default so that the serializer can be used as a nested field @@ -375,8 +405,14 @@ class BaseSerializer(WritableField):                  return          # Set the serializer object if it exists -        obj = getattr(self.parent.object, field_name) if self.parent.object else None -        obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj +        obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None + +        # If we have a model manager or similar object then we need +        # to iterate through each instance. +        if (self.many and +            not hasattr(obj, '__iter__') and +            is_simple_callable(getattr(obj, 'all', None))): +            obj = obj.all()          if self.source == '*':              if value: @@ -503,6 +539,9 @@ class BaseSerializer(WritableField):          """          Save the deserialized object and return it.          """ +        # Clear cached _data, which may be invalidated by `save()` +        self._data = None +          if isinstance(self.object, list):              [self.save_object(item, **kwargs) for item in self.object] @@ -751,6 +790,8 @@ class ModelSerializer(Serializer):          # TODO: TypedChoiceField?          if model_field.flatchoices:  # This ModelField contains choices              kwargs['choices'] = model_field.flatchoices +            if model_field.null: +                kwargs['empty'] = None              return ChoiceField(**kwargs)          # put this below the ChoiceField because min_value isn't a valid initializer @@ -822,13 +863,13 @@ class ModelSerializer(Serializer):          # Reverse fk or one-to-one relations          for (obj, model) in meta.get_all_related_objects_with_model(): -            field_name = obj.field.related_query_name() +            field_name = obj.get_accessor_name()              if field_name in attrs:                  related_data[field_name] = attrs.pop(field_name)          # Reverse m2m relations          for (obj, model) in meta.get_all_related_m2m_objects_with_model(): -            field_name = obj.field.related_query_name() +            field_name = obj.get_accessor_name()              if field_name in attrs:                  m2m_data[field_name] = attrs.pop(field_name) @@ -846,7 +887,10 @@ class ModelSerializer(Serializer):          # Update an existing instance...          if instance is not None:              for key, val in attrs.items(): -                setattr(instance, key, val) +                try: +                    setattr(instance, key, val) +                except ValueError: +                    self._errors[key] = self.error_messages['required']          # ...or create a new instance          else: @@ -872,7 +916,7 @@ class ModelSerializer(Serializer):      def save_object(self, obj, **kwargs):          """ -        Save the deserialized object and return it. +        Save the deserialized object.          """          if getattr(obj, '_nested_forward_relations', None):              # Nested relationships need to be saved before we can save the @@ -890,11 +934,16 @@ class ModelSerializer(Serializer):              del(obj._m2m_data)          if getattr(obj, '_related_data', None): +            related_fields = dict([ +                (field.get_accessor_name(), field) +                for field, model +                in obj._meta.get_all_related_objects_with_model() +            ])              for accessor_name, related in obj._related_data.items():                  if isinstance(related, RelationsList):                      # Nested reverse fk relationship                      for related_item in related: -                        fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name +                        fk_field = related_fields[accessor_name].field.name                          setattr(related_item, fk_field, obj)                          self.save_object(related_item) diff --git a/rest_framework/status.py b/rest_framework/status.py index b9f249f9..76435371 100644 --- a/rest_framework/status.py +++ b/rest_framework/status.py @@ -6,6 +6,23 @@ And RFC 6585 - http://tools.ietf.org/html/rfc6585  """  from __future__ import unicode_literals + +def is_informational(code): +    return code >= 100 and code <= 199 + +def is_success(code): +    return code >= 200 and code <= 299 + +def is_redirect(code): +    return code >= 300 and code <= 399 + +def is_client_error(code): +    return code >= 400 and code <= 499 + +def is_server_error(code): +    return code >= 500 and code <= 599 + +  HTTP_100_CONTINUE = 100  HTTP_101_SWITCHING_PROTOCOLS = 101  HTTP_200_OK = 200 diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 47377d51..42ede968 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -111,7 +111,9 @@          <div class="content-main">              <div class="page-header"><h1>{{ name }}</h1></div> +            {% block description %}              {{ description }} +            {% endblock %}              <div class="request-info" style="clear: both" >                  <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre>              </div> @@ -152,7 +154,7 @@                              {% with form=raw_data_post_form %}                              <form action="{{ request.get_full_path }}" method="POST" class="form-horizontal">                                  <fieldset> -                                    {% include "rest_framework/form.html" %} +                                    {% include "rest_framework/raw_data_form.html" %}                                      <div class="form-actions">                                          <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>                                      </div> @@ -189,7 +191,7 @@                              {% with form=raw_data_put_or_patch_form %}                              <form action="{{ request.get_full_path }}" method="POST" class="form-horizontal">                                  <fieldset> -                                    {% include "rest_framework/form.html" %} +                                    {% include "rest_framework/raw_data_form.html" %}                                      <div class="form-actions">                                          {% if raw_data_put_form %}                                          <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button> @@ -220,9 +222,6 @@      </div><!-- ./wrapper -->      {% block footer %} -    <!--<div id="footer"> -        <a class="powered-by" href='http://django-rest-framework.org'>Django REST framework</a> -    </div>-->      {% endblock %}      {% block script %} diff --git a/rest_framework/templates/rest_framework/form.html b/rest_framework/templates/rest_framework/form.html index b27f652e..b1e148df 100644 --- a/rest_framework/templates/rest_framework/form.html +++ b/rest_framework/templates/rest_framework/form.html @@ -1,13 +1,15 @@  {% load rest_framework %}  {% csrf_token %}  {{ form.non_field_errors }} -{% for field in form %} -    <div class="control-group"> <!--{% if field.errors %}error{% endif %}--> +{% for field in form.fields.values %} +    {% if not field.read_only %} +    <div class="control-group {% if field.errors %}error{% endif %}">          {{ field.label_tag|add_class:"control-label" }}          <div class="controls"> -            {{ field }} -            <span class="help-block">{{ field.help_text }}</span> -            <!--{{ field.errors|add_class:"help-block" }}--> +            {{ field.widget_html }} +            {% if field.help_text %}<span class="help-block">{{ field.help_text }}</span>{% endif %} +            {% for error in field.errors %}<span class="help-block">{{ error }}</span>{% endfor %}          </div>      </div> +    {% endif %}  {% endfor %} diff --git a/rest_framework/templates/rest_framework/raw_data_form.html b/rest_framework/templates/rest_framework/raw_data_form.html new file mode 100644 index 00000000..075279f7 --- /dev/null +++ b/rest_framework/templates/rest_framework/raw_data_form.html @@ -0,0 +1,12 @@ +{% load rest_framework %} +{% csrf_token %} +{{ form.non_field_errors }} +{% for field in form %} +    <div class="control-group"> +        {{ field.label_tag|add_class:"control-label" }} +        <div class="controls"> +            {{ field }} +            <span class="help-block">{{ field.help_text }}</span> +        </div> +    </div> +{% endfor %} diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 34fbab9c..5c96bce9 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -42,6 +42,31 @@ class TimeFieldModelSerializer(serializers.ModelSerializer):          model = TimeFieldModel +SAMPLE_CHOICES = [ +    ('red', 'Red'), +    ('green', 'Green'), +    ('blue', 'Blue'), +] + + +class ChoiceFieldModel(models.Model): +    choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, max_length=255) + + +class ChoiceFieldModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ChoiceFieldModel + + +class ChoiceFieldModelWithNull(models.Model): +    choice = models.CharField(choices=SAMPLE_CHOICES, blank=True, null=True, max_length=255) + + +class ChoiceFieldModelWithNullSerializer(serializers.ModelSerializer): +    class Meta: +        model = ChoiceFieldModelWithNull + +  class BasicFieldTests(TestCase):      def test_auto_now_fields_read_only(self):          """ @@ -667,34 +692,71 @@ class ChoiceFieldTests(TestCase):      """      Tests for the ChoiceField options generator      """ - -    SAMPLE_CHOICES = [ -        ('red', 'Red'), -        ('green', 'Green'), -        ('blue', 'Blue'), -    ] -      def test_choices_required(self):          """          Make sure proper choices are rendered if field is required          """ -        f = serializers.ChoiceField(required=True, choices=self.SAMPLE_CHOICES) -        self.assertEqual(f.choices, self.SAMPLE_CHOICES) +        f = serializers.ChoiceField(required=True, choices=SAMPLE_CHOICES) +        self.assertEqual(f.choices, SAMPLE_CHOICES)      def test_choices_not_required(self):          """          Make sure proper choices (plus blank) are rendered if the field isn't required          """ -        f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) -        self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) +        f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) +        self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES) + +    def test_invalid_choice_model(self): +        s = ChoiceFieldModelSerializer(data={'choice': 'wrong_value'}) +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'choice': ['Select a valid choice. wrong_value is not one of the available choices.']}) +        self.assertEqual(s.data['choice'], '') + +    def test_empty_choice_model(self): +        """ +        Test that the 'empty' value is correctly passed and used depending on +        the 'null' property on the model field. +        """ +        s = ChoiceFieldModelSerializer(data={'choice': ''}) +        self.assertTrue(s.is_valid()) +        self.assertEqual(s.data['choice'], '') + +        s = ChoiceFieldModelWithNullSerializer(data={'choice': ''}) +        self.assertTrue(s.is_valid()) +        self.assertEqual(s.data['choice'], None)      def test_from_native_empty(self):          """ -        Make sure from_native() returns None on empty param. +        Make sure from_native() returns an empty string on empty param by default.          """ -        f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES) -        result = f.from_native('') -        self.assertEqual(result, None) +        f = serializers.ChoiceField(choices=SAMPLE_CHOICES) +        self.assertEqual(f.from_native(''), '') +        self.assertEqual(f.from_native(None), '') + +    def test_from_native_empty_override(self): +        """ +        Make sure you can override from_native() behavior regarding empty values. +        """ +        f = serializers.ChoiceField(choices=SAMPLE_CHOICES, empty=None) +        self.assertEqual(f.from_native(''), None) +        self.assertEqual(f.from_native(None), None) + +    def test_metadata_choices(self): +        """ +        Make sure proper choices are included in the field's metadata. +        """ +        choices = [{'value': v, 'display_name': n} for v, n in SAMPLE_CHOICES] +        f = serializers.ChoiceField(choices=SAMPLE_CHOICES) +        self.assertEqual(f.metadata()['choices'], choices) + +    def test_metadata_choices_not_required(self): +        """ +        Make sure proper choices are included in the field's metadata. +        """ +        choices = [{'value': v, 'display_name': n} +                   for v, n in models.fields.BLANK_CHOICE_DASH + SAMPLE_CHOICES] +        f = serializers.ChoiceField(required=False, choices=SAMPLE_CHOICES) +        self.assertEqual(f.metadata()['choices'], choices)  class EmailFieldTests(TestCase): diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py index c13c38b8..78f4cf42 100644 --- a/rest_framework/tests/test_files.py +++ b/rest_framework/tests/test_files.py @@ -80,3 +80,16 @@ class FileSerializerTests(TestCase):          serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})          self.assertFalse(serializer.is_valid())          self.assertEqual(serializer.errors, {'file': [errmsg]}) + +    def test_validation_with_no_data(self): +        """ +        Validation should still function when no data dictionary is provided. +        """ +        now = datetime.datetime.now() +        file = BytesIO(six.b('stuff')) +        file.name = 'stuff.txt' +        file.size = len(file.getvalue()) +        uploaded_file = UploadedFile(file=file, created=now) + +        serializer = UploadedFileSerializer(files={'file': file}) +        self.assertFalse(serializer.is_valid()) diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index 9697c5ee..8a03a077 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -364,6 +364,12 @@ class OrdringFilterModel(models.Model):      text = models.CharField(max_length=100) +class OrderingFilterRelatedModel(models.Model): +    related_object = models.ForeignKey(OrdringFilterModel, +                                       related_name="relateds") + + +  class OrderingFilterTests(TestCase):      def setUp(self):          # Sequence of title/text is: @@ -473,3 +479,36 @@ class OrderingFilterTests(TestCase):                  {'id': 1, 'title': 'zyx', 'text': 'abc'},              ]          ) + +    def test_ordering_by_aggregate_field(self): +        # create some related models to aggregate order by +        num_objs = [2, 5, 3] +        for obj, num_relateds in zip(OrdringFilterModel.objects.all(), +                                     num_objs): +            for _ in range(num_relateds): +                new_related = OrderingFilterRelatedModel( +                    related_object=obj +                ) +                new_related.save() + +        class OrderingListView(generics.ListAPIView): +            model = OrdringFilterModel +            filter_backends = (filters.OrderingFilter,) +            ordering = 'title' +            queryset = OrdringFilterModel.objects.all().annotate( +                models.Count("relateds")) + +        view = OrderingListView.as_view() +        request = factory.get('?ordering=relateds__count') +        response = view(request) +        self.assertEqual( +            response.data, +            [ +                {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +            ] +        ) + + + diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 79cd99ac..996bd5b0 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -23,6 +23,10 @@ class InstanceView(generics.RetrieveUpdateDestroyAPIView):      """      model = BasicModel +    def get_queryset(self): +        queryset = super(InstanceView, self).get_queryset() +        return queryset.exclude(text='filtered out') +  class SlugSerializer(serializers.ModelSerializer):      slug = serializers.Field()  # read only @@ -160,10 +164,10 @@ class TestInstanceView(TestCase):          """          Create 3 BasicModel intances.          """ -        items = ['foo', 'bar', 'baz'] +        items = ['foo', 'bar', 'baz', 'filtered out']          for item in items:              BasicModel(text=item).save() -        self.objects = BasicModel.objects +        self.objects = BasicModel.objects.exclude(text='filtered out')          self.data = [              {'id': obj.id, 'text': obj.text}              for obj in self.objects.all() @@ -352,6 +356,17 @@ class TestInstanceView(TestCase):          updated = self.objects.get(id=1)          self.assertEqual(updated.text, 'foobar') +    def test_put_to_filtered_out_instance(self): +        """ +        PUT requests to an URL of instance which is filtered out should not be +        able to create new objects. +        """ +        data = {'text': 'foo'} +        filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk +        request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') +        response = self.view(request, pk=filtered_out_pk).render() +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) +      def test_put_as_create_on_id_based_url(self):          """          PUT requests to RetrieveUpdateDestroyAPIView should create an object @@ -508,6 +523,25 @@ class ExclusiveFilterBackend(object):          return queryset.filter(text='other') +class TwoFieldModel(models.Model): +    field_a = models.CharField(max_length=100) +    field_b = models.CharField(max_length=100) + + +class DynamicSerializerView(generics.ListCreateAPIView): +    model = TwoFieldModel +    renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) + +    def get_serializer_class(self): +        if self.request.method == 'POST': +            class DynamicSerializer(serializers.ModelSerializer): +                class Meta: +                    model = TwoFieldModel +                    fields = ('field_b',) +            return DynamicSerializer +        return super(DynamicSerializerView, self).get_serializer_class() + +  class TestFilterBackendAppliedToViews(TestCase):      def setUp(self): @@ -564,28 +598,6 @@ class TestFilterBackendAppliedToViews(TestCase):          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertEqual(response.data, {'id': 1, 'text': 'foo'}) - -class TwoFieldModel(models.Model): -    field_a = models.CharField(max_length=100) -    field_b = models.CharField(max_length=100) - - -class DynamicSerializerView(generics.ListCreateAPIView): -    model = TwoFieldModel -    renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) - -    def get_serializer_class(self): -        if self.request.method == 'POST': -            class DynamicSerializer(serializers.ModelSerializer): -                class Meta: -                    model = TwoFieldModel -                    fields = ('field_b',) -            return DynamicSerializer -        return super(DynamicSerializerView, self).get_serializer_class() - - -class TestFilterBackendAppliedToViews(TestCase): -      def test_dynamic_serializer_form_in_browsable_api(self):          """          GET requests to ListCreateAPIView should return filtered list. diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index d6bc7895..cadb515f 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -430,3 +430,88 @@ class TestCustomPaginationSerializer(TestCase):              'objects': ['john', 'paul']          }          self.assertEqual(serializer.data, expected) + + +class NonIntegerPage(object): + +    def __init__(self, paginator, object_list, prev_token, token, next_token): +        self.paginator = paginator +        self.object_list = object_list +        self.prev_token = prev_token +        self.token = token +        self.next_token = next_token + +    def has_next(self): +        return not not self.next_token + +    def next_page_number(self): +        return self.next_token + +    def has_previous(self): +        return not not self.prev_token + +    def previous_page_number(self): +        return self.prev_token + + +class NonIntegerPaginator(object): + +    def __init__(self, object_list, per_page): +        self.object_list = object_list +        self.per_page = per_page + +    def count(self): +        # pretend like we don't know how many pages we have +        return None + +    def page(self, token=None): +        if token: +            try: +                first = self.object_list.index(token) +            except ValueError: +                first = 0 +        else: +            first = 0 +        n = len(self.object_list) +        last = min(first + self.per_page, n) +        prev_token = self.object_list[last - (2 * self.per_page)] if first else None +        next_token = self.object_list[last] if last < n else None +        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) + + +class TestNonIntegerPagination(TestCase): + + +    def test_custom_pagination_serializer(self): +        objects = ['john', 'paul', 'george', 'ringo'] +        paginator = NonIntegerPaginator(objects, 2) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page(), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), +                'prev': None +            }, +            'total_results': None, +            'objects': objects[:2] +        } +        self.assertEqual(serializer.data, expected) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page('george'), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': None, +                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), +            }, +            'total_results': None, +            'objects': objects[2:] +        } +        self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py index d08124f4..6e3a6303 100644 --- a/rest_framework/tests/test_permissions.py +++ b/rest_framework/tests/test_permissions.py @@ -4,7 +4,7 @@ from django.db import models  from django.test import TestCase  from django.utils import unittest  from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING -from rest_framework.compat import guardian +from rest_framework.compat import guardian, get_model_name  from rest_framework.filters import DjangoObjectPermissionsFilter  from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import BasicModel @@ -202,7 +202,7 @@ class ObjectPermissionsIntegrationTests(TestCase):          # give everyone model level permissions, as we are not testing those          everyone = Group.objects.create(name='everyone') -        model_name = BasicPermModel._meta.module_name +        model_name = get_model_name(BasicPermModel)          app_label = BasicPermModel._meta.app_label          f = '{0}_{1}'.format          perms = { diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index 9d1dd77e..9cb68233 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -16,7 +16,9 @@ from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \  from rest_framework.parsers import YAMLParser, XMLParser  from rest_framework.settings import api_settings  from rest_framework.test import APIRequestFactory +from collections import MutableMapping  import datetime +import json  import pickle  import re @@ -65,11 +67,23 @@ class MockView(APIView):  class MockGETView(APIView): -      def get(self, request, **kwargs):          return Response({'foo': ['bar', 'baz']}) + +class MockPOSTView(APIView): +    def post(self, request, **kwargs): +        return Response({'foo': request.DATA}) + + +class EmptyGETView(APIView): +    renderer_classes = (JSONRenderer,) + +    def get(self, request, **kwargs): +        return Response(status=status.HTTP_204_NO_CONTENT) + +  class HTMLView(APIView):      renderer_classes = (BrowsableAPIRenderer, ) @@ -89,8 +103,10 @@ urlpatterns = patterns('',      url(r'^cache$', MockGETView.as_view()),      url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),      url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), +    url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),      url(r'^html$', HTMLView.as_view()),      url(r'^html1$', HTMLView1.as_view()), +    url(r'^empty$', EmptyGETView.as_view()),      url(r'^api', include('rest_framework.urls', namespace='rest_framework'))  ) @@ -220,6 +236,22 @@ class RendererEndToEndTests(TestCase):          self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))          self.assertEqual(resp.status_code, DUMMYSTATUS) +    def test_parse_error_renderers_browsable_api(self): +        """Invalid data should still render the browsable API correctly.""" +        resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') +        self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') +        self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) + +    def test_204_no_content_responses_have_no_content_type_set(self): +        """ +        Regression test for #1196 + +        https://github.com/tomchristie/django-rest-framework/issues/1196 +        """ +        resp = self.client.get('/empty') +        self.assertEqual(resp.get('Content-Type', None), None) +        self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) +  _flat_repr = '{"foo": ["bar", "baz"]}'  _indented_repr = '{\n  "foo": [\n    "bar",\n    "baz"\n  ]\n}' @@ -245,6 +277,44 @@ class JSONRendererTests(TestCase):          ret = JSONRenderer().render(_('test'))          self.assertEqual(ret, b'"test"') +    def test_render_dict_abc_obj(self): +        class Dict(MutableMapping): +            def __init__(self): +                self._dict = dict() +            def __getitem__(self, key): +                return self._dict.__getitem__(key) +            def __setitem__(self, key, value): +                return self._dict.__setitem__(key, value) +            def __delitem__(self, key): +                return self._dict.__delitem__(key) +            def __iter__(self): +                return self._dict.__iter__() +            def __len__(self): +                return self._dict.__len__() +            def keys(self): +                return self._dict.keys() + +        x = Dict() +        x['key'] = 'string value' +        x[2] = 3 +        ret = JSONRenderer().render(x) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, {'key': 'string value', '2': 3})     + +    def test_render_obj_with_getitem(self): +        class DictLike(object): +            def __init__(self): +                self._dict = {} +            def set(self, value): +                self._dict = dict(value) +            def __getitem__(self, key): +                return self._dict[key] +             +        x = DictLike() +        x.set({'a': 1, 'b': 'string'}) +        with self.assertRaises(TypeError): +            JSONRenderer().render(x) +              def test_without_content_type_args(self):          """          Test basic JSON rendering. @@ -329,7 +399,7 @@ if yaml:      class YAMLRendererTests(TestCase):          """ -        Tests specific to the JSON Renderer +        Tests specific to the YAML Renderer          """          def test_render(self): @@ -355,6 +425,17 @@ if yaml:              data = parser.parse(StringIO(content))              self.assertEqual(obj, data) +        def test_render_decimal(self): +            """ +            Test YAML decimal rendering. +            """ +            renderer = YAMLRenderer() +            content = renderer.render({'field': Decimal('111.2')}, 'application/yaml') +            self.assertYAMLContains(content, "field: '111.2'") + +        def assertYAMLContains(self, content, string): +            self.assertTrue(string in content, '%r not in %r' % (string, content)) +  class XMLRendererTestCase(TestCase):      """ diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index d6363425..f07c31a3 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -6,6 +6,7 @@ from django.conf.urls import patterns  from django.contrib.auth.models import User  from django.contrib.auth import authenticate, login, logout  from django.contrib.sessions.middleware import SessionMiddleware +from django.core.handlers.wsgi import WSGIRequest  from django.test import TestCase  from rest_framework import status  from rest_framework.authentication import SessionAuthentication @@ -15,12 +16,13 @@ from rest_framework.parsers import (      MultiPartParser,      JSONParser  ) -from rest_framework.request import Request +from rest_framework.request import Request, Empty  from rest_framework.response import Response  from rest_framework.settings import api_settings  from rest_framework.test import APIRequestFactory, APIClient  from rest_framework.views import APIView  from rest_framework.compat import six +from io import BytesIO  import json @@ -146,6 +148,34 @@ class TestContentParsing(TestCase):          request.parsers = (JSONParser(), )          self.assertEqual(request.DATA, json_data) +    def test_form_POST_unicode(self): +        """ +        JSON POST via default web interface with unicode data +        """ +        # Note: environ and other variables here have simplified content compared to real Request +        CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D' +        environ = { +            'REQUEST_METHOD': 'POST', +            'CONTENT_TYPE': 'application/x-www-form-urlencoded', +            'CONTENT_LENGTH': len(CONTENT), +            'wsgi.input': BytesIO(CONTENT), +        } +        wsgi_request = WSGIRequest(environ=environ) +        wsgi_request._load_post_and_files() +        parsers = (JSONParser(), FormParser(), MultiPartParser()) +        parser_context = { +            'encoding': 'utf-8', +            'kwargs': {}, +            'args': (), +        } +        request = Request(wsgi_request, parsers=parsers, parser_context=parser_context) +        method = request.method +        self.assertEqual(method, 'POST') +        self.assertEqual(request._content_type, 'application/json') +        self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}') +        self.assertEqual(request._data, Empty) +        self.assertEqual(request._files, Empty) +      # def test_accessing_post_after_data_form(self):      #     """      #     Ensures request.POST can be accessed after request.DATA in diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 739bb70a..e80276e9 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*-  from __future__ import unicode_literals  from django.db import models  from django.db.models.fields import BLANK_CHOICE_DASH @@ -136,6 +137,7 @@ class BasicTests(TestCase):              'Happy new year!',              datetime.datetime(2012, 1, 1)          ) +        self.actionitem = ActionItem(title='Some to do item',)          self.data = {              'email': 'tom@example.com',              'content': 'Happy new year!', @@ -157,8 +159,7 @@ class BasicTests(TestCase):          expected = {              'email': '',              'content': '', -            'created': None, -            'sub_comment': '' +            'created': None          }          self.assertEqual(serializer.data, expected) @@ -264,6 +265,20 @@ class BasicTests(TestCase):          """          self.assertRaises(AssertionError, PersonSerializerInvalidReadOnly, []) +    def test_serializer_data_is_cleared_on_save(self): +        """ +        Check _data attribute is cleared on `save()` + +        Regression test for #1116 +            — id field is not populated if `data` is accessed prior to `save()` +        """ +        serializer = ActionItemSerializer(self.actionitem) +        self.assertIsNone(serializer.data.get('id',None), 'New instance. `id` should not be set.') +        serializer.save() +        self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.') + + +  class DictStyleSerializer(serializers.Serializer):      """ @@ -496,6 +511,33 @@ class CustomValidationTests(TestCase):          self.assertFalse(serializer.is_valid())          self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']}) +    def test_partial_update(self): +        """ +        Make sure that validate_email isn't called when partial=True and email +        isn't found in data. +        """ +        initial_data = { +            'email': 'tom@example.com', +            'content': 'A test comment', +            'created': datetime.datetime(2012, 1, 1) +        } + +        serializer = self.CommentSerializerWithFieldValidator(data=initial_data) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.object + +        new_content = 'An *updated* test comment' +        partial_data = { +            'content': new_content +        } + +        serializer = self.CommentSerializerWithFieldValidator(instance=instance, +                                                              data=partial_data, +                                                              partial=True) +        self.assertEqual(serializer.is_valid(), True) +        instance = serializer.object +        self.assertEqual(instance.content, new_content) +  class PositiveIntegerAsChoiceTests(TestCase):      def test_positive_integer_in_json_is_correctly_parsed(self): @@ -516,6 +558,29 @@ class ModelValidationTests(TestCase):          self.assertFalse(second_serializer.is_valid())          self.assertEqual(second_serializer.errors,  {'title': ['Album with this Title already exists.']}) +    def test_foreign_key_is_null_with_partial(self): +        """ +        Test ModelSerializer validation with partial=True + +        Specifically test that a null foreign key does not pass validation +        """ +        album = Album(title='test') +        album.save() + +        class PhotoSerializer(serializers.ModelSerializer): +            class Meta: +                model = Photo + +        photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) +        self.assertTrue(photo_serializer.is_valid()) +        photo = photo_serializer.save() + +        # Updating only the album (foreign key) +        photo_serializer = PhotoSerializer(instance=photo, data={'album': ''}, partial=True) +        self.assertFalse(photo_serializer.is_valid()) +        self.assertTrue('album' in photo_serializer.errors) +        self.assertEqual(photo_serializer.errors['album'], photo_serializer.error_messages['required']) +      def test_foreign_key_with_partial(self):          """          Test ModelSerializer validation with partial=True @@ -1643,3 +1708,38 @@ class SerializerSupportsManyRelationships(TestCase):          serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})          self.assertTrue(serializer.is_valid())          self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]}) + + +class TransformMethodsSerializer(serializers.Serializer): +    a = serializers.CharField() +    b_renamed = serializers.CharField(source='b') + +    def transform_a(self, obj, value): +        return value.lower() + +    def transform_b_renamed(self, obj, value): +        if value is not None: +            return 'and ' + value + + +class TestSerializerTransformMethods(TestCase): +    def setUp(self): +        self.s = TransformMethodsSerializer() + +    def test_transform_methods(self): +        self.assertEqual( +            self.s.to_native({'a': 'GREEN EGGS', 'b': 'HAM'}), +            { +                'a': 'green eggs', +                'b_renamed': 'and HAM', +            } +        ) + +    def test_missing_fields(self): +        self.assertEqual( +            self.s.to_native({'a': 'GREEN EGGS'}), +            { +                'a': 'green eggs', +                'b_renamed': None, +            } +        ) diff --git a/rest_framework/tests/test_serializer_empty.py b/rest_framework/tests/test_serializer_empty.py new file mode 100644 index 00000000..30cff361 --- /dev/null +++ b/rest_framework/tests/test_serializer_empty.py @@ -0,0 +1,15 @@ +from django.test import TestCase +from rest_framework import serializers + + +class EmptySerializerTestCase(TestCase): +    def test_empty_serializer(self): +        class FooBarSerializer(serializers.Serializer): +            foo = serializers.IntegerField() +            bar = serializers.SerializerMethodField('get_bar') + +            def get_bar(self, obj): +                return 'bar' + +        serializer = FooBarSerializer() +        self.assertEquals(serializer.data, {'foo': 0}) diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py index 71d0e24b..7114a060 100644 --- a/rest_framework/tests/test_serializer_nested.py +++ b/rest_framework/tests/test_serializer_nested.py @@ -6,6 +6,7 @@ Doesn't cover model serializers.  from __future__ import unicode_literals  from django.test import TestCase  from rest_framework import serializers +from . import models  class WritableNestedSerializerBasicTests(TestCase): @@ -244,3 +245,104 @@ class WritableNestedSerializerObjectTests(TestCase):          serializer = self.AlbumSerializer(data=data, many=True)          self.assertEqual(serializer.is_valid(), True)          self.assertEqual(serializer.object, expected_object) + + +class ForeignKeyNestedSerializerUpdateTests(TestCase): +    def setUp(self): +        class Artist(object): +            def __init__(self, name): +                self.name = name + +            def __eq__(self, other): +                return self.name == other.name + +        class Album(object): +            def __init__(self, name, artist): +                self.name, self.artist = name, artist + +            def __eq__(self, other): +                return self.name == other.name and self.artist == other.artist + +        class ArtistSerializer(serializers.Serializer): +            name = serializers.CharField() + +            def restore_object(self, attrs, instance=None): +                if instance: +                    instance.name = attrs['name'] +                else: +                    instance = Artist(attrs['name']) +                return instance + +        class AlbumSerializer(serializers.Serializer): +            name = serializers.CharField() +            by = ArtistSerializer(source='artist') + +            def restore_object(self, attrs, instance=None): +                if instance: +                    instance.name = attrs['name'] +                    instance.artist = attrs['artist'] +                else: +                    instance = Album(attrs['name'], attrs['artist']) +                return instance + +        self.Artist = Artist +        self.Album = Album +        self.AlbumSerializer = AlbumSerializer + +    def test_create_via_foreign_key_with_source(self): +        """ +        Check that we can both *create* and *update* into objects across +        ForeignKeys that have a `source` specified. +        Regression test for #1170 +        """ +        data = { +            'name': 'Discovery', +            'by': {'name': 'Daft Punk'}, +        } + +        expected = self.Album(artist=self.Artist('Daft Punk'), name='Discovery') + +        # create +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected) + +        # update +        original = self.Album(artist=self.Artist('The Bats'), name='Free All the Monsters') +        serializer = self.AlbumSerializer(instance=original, data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected) + + +class NestedModelSerializerUpdateTests(TestCase): +    def test_second_nested_level(self): +        john = models.Person.objects.create(name="john") + +        post = john.blogpost_set.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostCommentSerializer(serializers.ModelSerializer): +            class Meta: +                model = models.BlogPostComment + +        class BlogPostSerializer(serializers.ModelSerializer): +            comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set') +            class Meta: +                model = models.BlogPost +                fields = ('id', 'title', 'comments') + +        class PersonSerializer(serializers.ModelSerializer): +            posts = BlogPostSerializer(many=True, source='blogpost_set') +            class Meta: +                model = models.Person +                fields = ('id', 'name', 'age', 'posts') + +        serialize = PersonSerializer(instance=john) +        deserialize = PersonSerializer(data=serialize.data, instance=john) +        self.assertTrue(deserialize.is_valid()) + +        result = deserialize.object +        result.save() +        self.assertEqual(result.id, john.id) + diff --git a/rest_framework/tests/test_status.py b/rest_framework/tests/test_status.py new file mode 100644 index 00000000..7b1bdae3 --- /dev/null +++ b/rest_framework/tests/test_status.py @@ -0,0 +1,33 @@ +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.status import ( +    is_informational, is_success, is_redirect, is_client_error, is_server_error +) + + +class TestStatus(TestCase): +    def test_status_categories(self): +        self.assertFalse(is_informational(99)) +        self.assertTrue(is_informational(100)) +        self.assertTrue(is_informational(199)) +        self.assertFalse(is_informational(200)) + +        self.assertFalse(is_success(199)) +        self.assertTrue(is_success(200)) +        self.assertTrue(is_success(299)) +        self.assertFalse(is_success(300)) + +        self.assertFalse(is_redirect(299)) +        self.assertTrue(is_redirect(300)) +        self.assertTrue(is_redirect(399)) +        self.assertFalse(is_redirect(400)) + +        self.assertFalse(is_client_error(399)) +        self.assertTrue(is_client_error(400)) +        self.assertTrue(is_client_error(499)) +        self.assertFalse(is_client_error(500)) + +        self.assertFalse(is_server_error(499)) +        self.assertTrue(is_server_error(500)) +        self.assertTrue(is_server_error(599)) +        self.assertFalse(is_server_error(600))
\ No newline at end of file diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index a62530c7..038e9ee3 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -57,6 +57,6 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):              allowed_pattern = '(%s)' % '|'.join(allowed)          suffix_pattern = r'\.(?P<%s>%s)$' % (suffix_kwarg, allowed_pattern)      else: -        suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg +        suffix_pattern = r'\.(?P<%s>[a-z0-9]+)$' % suffix_kwarg      return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 13a85550..229b0b28 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -45,6 +45,11 @@ class JSONEncoder(json.JSONEncoder):              return str(o)          elif hasattr(o, 'tolist'):              return o.tolist() +        elif hasattr(o, '__getitem__'): +            try: +                return dict(o) +            except: +                pass          elif hasattr(o, '__iter__'):              return [i for i in o]          return super(JSONEncoder, self).default(o) @@ -90,6 +95,9 @@ else:                      node.flow_style = best_style              return node +    SafeDumper.add_representer(decimal.Decimal, +            SafeDumper.represent_decimal) +      SafeDumper.add_representer(SortedDict,              yaml.representer.SafeRepresenter.represent_dict)      SafeDumper.add_representer(DictWithMetadata, diff --git a/rest_framework/views.py b/rest_framework/views.py index 853e6461..e863af6d 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -154,8 +154,8 @@ class APIView(View):          Returns a dict that is passed through to Parser.parse(),          as the `parser_context` keyword argument.          """ -        # Note: Additionally `request` will also be added to the context -        #       by the Request object. +        # Note: Additionally `request` and `encoding` will also be added +        #       to the context by the Request object.          return {              'view': self,              'args': getattr(self, 'args', ()), diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index d91323f2..7eb29f99 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -9,7 +9,7 @@ Actions are only bound to methods at the point of instantiating the views.      user_detail = UserViewSet.as_view({'get': 'retrieve'})  Typically, rather than instantiate views from viewsets directly, you'll -regsiter the viewset with a router and let the URL conf be determined +register the viewset with a router and let the URL conf be determined  automatically.      router = DefaultRouter() | 
