diff options
| author | Tom Christie | 2013-08-30 09:28:33 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-08-30 09:28:33 +0100 | 
| commit | 9a5b2eefa92dede844ab94d049093e91ac98af5b (patch) | |
| tree | faf389e2f8c8296aeaa486ab97ed0be9113cc2ba /rest_framework | |
| parent | bf07b8e616bd92e4ae3c2c09b198181d7075e6bd (diff) | |
| parent | f3ab0b2b1d5734314dbe3cdd13cd7c4f0531bf7d (diff) | |
| download | django-rest-framework-9a5b2eefa92dede844ab94d049093e91ac98af5b.tar.bz2 | |
Merge master
Diffstat (limited to 'rest_framework')
23 files changed, 878 insertions, 277 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 518ba41a..b3a9b0df 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -301,7 +301,10 @@ class WritableField(Field):          try:              if self.use_files:                  files = files or {} -                native = files[field_name] +                try: +                    native = files[field_name] +                except KeyError: +                    native = data[field_name]              else:                  native = data[field_name]          except KeyError: @@ -504,6 +507,11 @@ class ChoiceField(WritableField):                      return True          return False +    def from_native(self, value): +        if value in validators.EMPTY_VALUES: +            return None +        return super(ChoiceField, self).from_native(value) +  class EmailField(CharField):      type_name = 'EmailField' diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 8e6b8e26..851f8474 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -14,13 +14,15 @@ from rest_framework.settings import api_settings  import warnings -def strict_positive_int(integer_string): +def strict_positive_int(integer_string, cutoff=None):      """      Cast a string to a strictly positive integer.      """      ret = int(integer_string)      if ret <= 0:          raise ValueError() +    if cutoff: +        ret = min(ret, cutoff)      return ret  def get_object_or_404(queryset, **filter_kwargs): @@ -56,6 +58,7 @@ class GenericAPIView(views.APIView):      # Pagination settings      paginate_by = api_settings.PAGINATE_BY      paginate_by_param = api_settings.PAGINATE_BY_PARAM +    max_paginate_by = api_settings.MAX_PAGINATE_BY      pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS      page_kwarg = 'page' @@ -205,9 +208,11 @@ class GenericAPIView(views.APIView):                            DeprecationWarning, stacklevel=2)          if self.paginate_by_param: -            query_params = self.request.QUERY_PARAMS              try: -                return strict_positive_int(query_params[self.paginate_by_param]) +                return strict_positive_int( +                    self.request.QUERY_PARAMS[self.paginate_by_param], +                    cutoff=self.max_paginate_by +                )              except (KeyError, ValueError):                  pass diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 679dfa6c..2c85d157 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -142,11 +142,16 @@ class UpdateModelMixin(object):          try:              return self.get_object()          except Http404: -            # If this is a PUT-as-create operation, we need to ensure that -            # we have relevant permissions, as if this was a POST request. -            # This will either raise a PermissionDenied exception, -            # or simply return None -            self.check_permissions(clone_request(self.request, 'POST')) +            if self.request.method == 'PUT': +                # For PUT-as-create operation, we need to ensure that we have +                # relevant permissions, as if this was a POST request.  This +                # will either raise a PermissionDenied exception, or simply +                # return None. +                self.check_permissions(clone_request(self.request, 'POST')) +            else: +                # PATCH requests where the object does not exist should still +                # return a 404 response. +                raise      def pre_save(self, obj):          """ diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 96bfac84..98fc0341 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -10,9 +10,9 @@ from django.core.files.uploadhandler import StopFutureHandlers  from django.http import QueryDict  from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser  from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter -from rest_framework.compat import yaml, etree +from rest_framework.compat import etree, six, yaml  from rest_framework.exceptions import ParseError -from rest_framework.compat import six +from rest_framework import renderers  import json  import datetime  import decimal @@ -47,6 +47,7 @@ class JSONParser(BaseParser):      """      media_type = 'application/json' +    renderer_class = renderers.UnicodeJSONRenderer      def parse(self, stream, media_type=None, parser_context=None):          """ @@ -121,7 +122,8 @@ class MultiPartParser(BaseParser):          parser_context = parser_context or {}          request = parser_context['request']          encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) -        meta = request.META +        meta = request.META.copy() +        meta['CONTENT_TYPE'] = media_type          upload_handlers = request.upload_handlers          try: @@ -129,7 +131,7 @@ class MultiPartParser(BaseParser):              data, files = parser.parse()              return DataAndFiles(data, files)          except MultiPartParserError as exc: -            raise ParseError('Multipart form parse error - %s' % six.u(exc)) +            raise ParseError('Multipart form parse error - %s' % str(exc))  class XMLParser(BaseParser): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index f1f7dea7..417925b5 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -126,9 +126,9 @@ class RelatedField(WritableField):              value = obj              for component in source.split('.'): -                value = get_component(value, component)                  if value is None:                      break +                value = get_component(value, component)          except ObjectDoesNotExist:              return None @@ -236,6 +236,8 @@ class PrimaryKeyRelatedField(RelatedField):                  source = self.source or field_name                  queryset = obj                  for component in source.split('.'): +                    if queryset is None: +                        return []                      queryset = get_component(queryset, component)              # Forward relationship @@ -556,8 +558,13 @@ class HyperlinkedIdentityField(Field):          May raise a `NoReverseMatch` if the `view_name` and `lookup_field`          attributes are not configured to correctly match the URL conf.          """ -        lookup_field = getattr(obj, self.lookup_field) +        lookup_field = getattr(obj, self.lookup_field, None)          kwargs = {self.lookup_field: lookup_field} + +        # Handle unsaved object case +        if lookup_field is None: +            return None +          try:              return reverse(view_name, kwargs=kwargs, request=request, format=format)          except NoReverseMatch: diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1006e26c..fca67eee 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -21,10 +21,10 @@ from rest_framework.compat import six  from rest_framework.compat import smart_text  from rest_framework.compat import yaml  from rest_framework.settings import api_settings -from rest_framework.request import clone_request +from rest_framework.request import is_form_media_type, override_method  from rest_framework.utils import encoders  from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework import exceptions, parsers, status, VERSION +from rest_framework import exceptions, status, VERSION  class BaseRenderer(object): @@ -36,6 +36,7 @@ class BaseRenderer(object):      media_type = None      format = None      charset = 'utf-8' +    render_style = 'text'      def render(self, data, accepted_media_type=None, renderer_context=None):          raise NotImplemented('Renderer class requires .render() to be implemented') @@ -51,16 +52,17 @@ class JSONRenderer(BaseRenderer):      format = 'json'      encoder_class = encoders.JSONEncoder      ensure_ascii = True -    charset = 'utf-8' -    # Note that JSON encodings must be utf-8, utf-16 or utf-32. +    charset = None +    # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32.      # See: http://www.ietf.org/rfc/rfc4627.txt +    # Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/      def render(self, data, accepted_media_type=None, renderer_context=None):          """          Render `data` into JSON.          """          if data is None: -            return '' +            return bytes()          # If 'indent' is provided in the context, then pretty print the result.          # E.g. If we're being called by the BrowsableAPIRenderer. @@ -85,13 +87,12 @@ class JSONRenderer(BaseRenderer):          # and may (or may not) be unicode.          # On python 3.x json.dumps() returns unicode strings.          if isinstance(ret, six.text_type): -            return bytes(ret.encode(self.charset)) +            return bytes(ret.encode('utf-8'))          return ret  class UnicodeJSONRenderer(JSONRenderer):      ensure_ascii = False -    charset = 'utf-8'      """      Renderer which serializes to JSON.      Does *not* apply JSON's character escaping for non-ascii characters. @@ -108,6 +109,7 @@ class JSONPRenderer(JSONRenderer):      format = 'jsonp'      callback_parameter = 'callback'      default_callback = 'callback' +    charset = 'utf-8'      def get_callback(self, renderer_context):          """ @@ -316,6 +318,90 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):          return data +class HTMLFormRenderer(BaseRenderer): +    """ +    Renderers serializer data into an HTML form. + +    If the serializer was instantiated without an object then this will +    return an HTML form not bound to any object, +    otherwise it will return an HTML form with the appropriate initial data +    populated from the object. + +    Note that rendering of field and form errors is not currently supported. +    """ +    media_type = 'text/html' +    format = 'form' +    template = 'rest_framework/form.html' +    charset = 'utf-8' + +    def data_to_form_fields(self, data): +        fields = {} +        for key, val in data.fields.items(): +            if getattr(val, 'read_only', True): +                # Don't include read-only fields. +                continue + +            if getattr(val, 'fields', None): +                # Nested data not supported by HTML forms. +                continue + +            kwargs = {} +            kwargs['required'] = val.required + +            #if getattr(v, 'queryset', None): +            #    kwargs['queryset'] = v.queryset + +            if getattr(val, 'choices', None) is not None: +                kwargs['choices'] = val.choices + +            if getattr(val, 'regex', None) is not None: +                kwargs['regex'] = val.regex + +            if getattr(val, 'widget', None): +                widget = copy.deepcopy(val.widget) +                kwargs['widget'] = widget + +            if getattr(val, 'default', None) is not None: +                kwargs['initial'] = val.default + +            if getattr(val, 'label', None) is not None: +                kwargs['label'] = val.label + +            if getattr(val, 'help_text', None) is not None: +                kwargs['help_text'] = val.help_text + +            fields[key] = val.form_field_class(**kwargs) + +        return fields + +    def render(self, data, accepted_media_type=None, renderer_context=None): +        """ +        Render serializer data and return an HTML form, as a string. +        """ +        # The HTMLFormRenderer currently uses something of a hack to render +        # the content, by translating each of the serializer fields into +        # an html form field, creating a dynamic form using those fields, +        # and then rendering that form. + +        # This isn't strictly neccessary, as we could render the serilizer +        # fields to HTML directly.  The implementation is historical and will +        # likely change at some point. + +        self.renderer_context = renderer_context or {} +        request = renderer_context['request'] + +        # Creating an on the fly form see: +        # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python +        fields = self.data_to_form_fields(data) +        DynamicForm = type(str('DynamicForm'), (forms.Form,), fields) +        data = None if data.empty else data + +        template = loader.get_template(self.template) +        context = RequestContext(request, {'form': DynamicForm(data)}) + +        return template.render(context) + +  class BrowsableAPIRenderer(BaseRenderer):      """      HTML renderer used to self-document the API. @@ -324,6 +410,7 @@ class BrowsableAPIRenderer(BaseRenderer):      format = 'api'      template = 'rest_framework/api.html'      charset = 'utf-8' +    form_renderer_class = HTMLFormRenderer      def get_default_renderer(self, view):          """ @@ -348,7 +435,10 @@ class BrowsableAPIRenderer(BaseRenderer):          renderer_context['indent'] = 4          content = renderer.render(data, accepted_media_type, renderer_context) -        if renderer.charset is None: +        render_style = getattr(renderer, 'render_style', 'text') +        assert render_style in ['text', 'binary'], 'Expected .render_style ' \ +            '"text" or "binary", but got "%s"' % render_style +        if render_style == 'binary':              return '[%d bytes of binary content]' % len(content)          return content @@ -371,130 +461,99 @@ class BrowsableAPIRenderer(BaseRenderer):              return False  # Doesn't have permissions          return True -    def serializer_to_form_fields(self, serializer): -        fields = {} -        for k, v in serializer.get_fields().items(): -            if getattr(v, 'read_only', True): -                continue - -            kwargs = {} -            kwargs['required'] = v.required - -            #if getattr(v, 'queryset', None): -            #    kwargs['queryset'] = v.queryset - -            if getattr(v, 'choices', None) is not None: -                kwargs['choices'] = v.choices - -            if getattr(v, 'regex', None) is not None: -                kwargs['regex'] = v.regex - -            if getattr(v, 'widget', None): -                widget = copy.deepcopy(v.widget) -                kwargs['widget'] = widget - -            if getattr(v, 'default', None) is not None: -                kwargs['initial'] = v.default - -            if getattr(v, 'label', None) is not None: -                kwargs['label'] = v.label - -            if getattr(v, 'help_text', None) is not None: -                kwargs['help_text'] = v.help_text - -            fields[k] = v.form_field_class(**kwargs) - -        return fields - -    def _get_form(self, view, method, request): -        # We need to impersonate a request with the correct method, -        # so that eg. any dynamic get_serializer_class methods return the -        # correct form for each method. -        restore = view.request -        request = clone_request(request, method) -        view.request = request -        try: -            return self.get_form(view, method, request) -        finally: -            view.request = restore - -    def _get_raw_data_form(self, view, method, request, media_types): -        # We need to impersonate a request with the correct method, -        # so that eg. any dynamic get_serializer_class methods return the -        # correct form for each method. -        restore = view.request -        request = clone_request(request, method) -        view.request = request -        try: -            return self.get_raw_data_form(view, method, request, media_types) -        finally: -            view.request = restore - -    def get_form(self, view, method, request): +    def get_rendered_html_form(self, view, method, request):          """ -        Get a form, possibly bound to either the input or output data. -        In the absence on of the Resource having an associated form then -        provide a form that can be used to submit arbitrary content. +        Return a string representing a rendered HTML form, possibly bound to +        either the input or output data. + +        In the absence of the View having an associated form then return None.          """ -        obj = getattr(view, 'object', None) -        if not self.show_form_for_method(view, method, request, obj): -            return +        with override_method(view, request, method) as request: +            obj = getattr(view, 'object', None) +            if not self.show_form_for_method(view, method, request, obj): +                return -        if method in ('DELETE', 'OPTIONS'): -            return True  # Don't actually need to return a form +            if method in ('DELETE', 'OPTIONS'): +                return True  # Don't actually need to return a form -        if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes: -            return +            if (not getattr(view, 'get_serializer', None) +                or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)): +                return -        serializer = view.get_serializer(instance=obj) -        fields = self.serializer_to_form_fields(serializer) +            serializer = view.get_serializer(instance=obj) -        # Creating an on the fly form see: -        # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python -        OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) -        data = (obj is not None) and serializer.data or None -        form_instance = OnTheFlyForm(data) -        return form_instance +            data = serializer.data +            form_renderer = self.form_renderer_class() +            return form_renderer.render(data, self.accepted_media_type, self.renderer_context) -    def get_raw_data_form(self, view, method, request, media_types): +    def get_raw_data_form(self, view, method, request):          """          Returns a form that allows for arbitrary content types to be tunneled          via standard HTML forms.          (Which are typically application/x-www-form-urlencoded)          """ - -        # If we're not using content overloading there's no point in supplying a generic form, -        # as the view won't treat the form's value as the content of the request. -        if not (api_settings.FORM_CONTENT_OVERRIDE -                and api_settings.FORM_CONTENTTYPE_OVERRIDE): -            return None - -        # Check permissions -        obj = getattr(view, 'object', None) -        if not self.show_form_for_method(view, method, request, obj): -            return - -        content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE -        content_field = api_settings.FORM_CONTENT_OVERRIDE -        choices = [(media_type, media_type) for media_type in media_types] -        initial = media_types[0] - -        # NB. http://jacobian.org/writing/dynamic-form-generation/ -        class GenericContentForm(forms.Form): -            def __init__(self): -                super(GenericContentForm, self).__init__() - -                self.fields[content_type_field] = forms.ChoiceField( -                    label='Media type', -                    choices=choices, -                    initial=initial -                ) -                self.fields[content_field] = forms.CharField( -                    label='Content', -                    widget=forms.Textarea -                ) - -        return GenericContentForm() +        with override_method(view, request, method) as request: +            # If we're not using content overloading there's no point in +            # supplying a generic form, as the view won't treat the form's +            # value as the content of the request. +            if not (api_settings.FORM_CONTENT_OVERRIDE +                    and api_settings.FORM_CONTENTTYPE_OVERRIDE): +                return None + +            # Check permissions +            obj = getattr(view, 'object', None) +            if not self.show_form_for_method(view, method, request, obj): +                return + +            # If possible, serialize the initial content for the generic form +            default_parser = view.parser_classes[0] +            renderer_class = getattr(default_parser, 'renderer_class', None) +            if (hasattr(view, 'get_serializer') and renderer_class): +                # View has a serializer defined and parser class has a +                # corresponding renderer that can be used to render the data. + +                # Get a read-only version of the serializer +                serializer = view.get_serializer(instance=obj) +                if obj is None: +                    for name, field in serializer.fields.items(): +                        if getattr(field, 'read_only', None): +                            del serializer.fields[name] + +                # Render the raw data content +                renderer = renderer_class() +                accepted = self.accepted_media_type +                context = self.renderer_context.copy() +                context['indent'] = 4 +                content = renderer.render(serializer.data, accepted, context) +            else: +                content = None + +            # Generate a generic form that includes a content type field, +            # and a content field. +            content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE +            content_field = api_settings.FORM_CONTENT_OVERRIDE + +            media_types = [parser.media_type for parser in view.parser_classes] +            choices = [(media_type, media_type) for media_type in media_types] +            initial = media_types[0] + +            # NB. http://jacobian.org/writing/dynamic-form-generation/ +            class GenericContentForm(forms.Form): +                def __init__(self): +                    super(GenericContentForm, self).__init__() + +                    self.fields[content_type_field] = forms.ChoiceField( +                        label='Media type', +                        choices=choices, +                        initial=initial +                    ) +                    self.fields[content_field] = forms.CharField( +                        label='Content', +                        widget=forms.Textarea, +                        initial=content +                    ) + +            return GenericContentForm()      def get_name(self, view):          return view.get_view_name() @@ -509,26 +568,25 @@ class BrowsableAPIRenderer(BaseRenderer):          """          Render the HTML for the browsable API representation.          """ -        accepted_media_type = accepted_media_type or '' -        renderer_context = renderer_context or {} +        self.accepted_media_type = accepted_media_type or '' +        self.renderer_context = renderer_context or {}          view = renderer_context['view']          request = renderer_context['request']          response = renderer_context['response'] -        media_types = [parser.media_type for parser in view.parser_classes]          renderer = self.get_default_renderer(view)          content = self.get_content(renderer, data, accepted_media_type, renderer_context) -        put_form = self._get_form(view, 'PUT', request) -        post_form = self._get_form(view, 'POST', request) -        patch_form = self._get_form(view, 'PATCH', request) -        delete_form = self._get_form(view, 'DELETE', request) -        options_form = self._get_form(view, 'OPTIONS', request) +        put_form = self.get_rendered_html_form(view, 'PUT', request) +        post_form = self.get_rendered_html_form(view, 'POST', request) +        patch_form = self.get_rendered_html_form(view, 'PATCH', request) +        delete_form = self.get_rendered_html_form(view, 'DELETE', request) +        options_form = self.get_rendered_html_form(view, 'OPTIONS', request) -        raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types) -        raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) -        raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types) +        raw_data_put_form = self.get_raw_data_form(view, 'PUT', request) +        raw_data_post_form = self.get_raw_data_form(view, 'POST', request) +        raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request)          raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form          name = self.get_name(view) @@ -581,3 +639,4 @@ class MultiPartRenderer(BaseRenderer):      def render(self, data, accepted_media_type=None, renderer_context=None):          return encode_multipart(self.BOUNDARY, data) + diff --git a/rest_framework/request.py b/rest_framework/request.py index 919716f4..977d4d96 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -28,6 +28,29 @@ def is_form_media_type(media_type):              base_media_type == 'multipart/form-data') +class override_method(object): +    """ +    A context manager that temporarily overrides the method on a request, +    additionally setting the `view.request` attribute. + +    Usage: + +        with override_method(view, request, 'POST') as request: +            ... # Do stuff with `view` and `request` +    """ +    def __init__(self, view, request, method): +        self.view = view +        self.request = request +        self.method = method + +    def __enter__(self): +        self.view.request = clone_request(self.request, self.method) +        return self.view.request + +    def __exit__(self, *args, **kwarg): +        self.view.request = self.request + +  class Empty(object):      """      Placeholder for unset attributes. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index b761ba9a..1c7a8158 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -213,7 +213,11 @@ class SimpleRouter(BaseRouter):          Given a viewset, return the portion of URL regex that is used          to match against a single instance.          """ -        base_regex = '(?P<{lookup_field}>[^/]+)' +        if self.trailing_slash: +            base_regex = '(?P<{lookup_field}>[^/]+)' +        else: +            # Don't consume `.json` style suffixes +            base_regex = '(?P<{lookup_field}>[^/.]+)'          lookup_field = getattr(viewset, 'lookup_field', 'pk')          return base_regex.format(lookup_field=lookup_field) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b3850157..f1775762 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -31,6 +31,9 @@ from rest_framework.relations import *  from rest_framework.fields import * +class RelationsList(list): +    _deleted = [] +  class NestedValidationError(ValidationError):      """      The default ValidationError behavior is to stringify each item in the list @@ -160,7 +163,6 @@ class BaseSerializer(WritableField):          self._data = None          self._files = None          self._errors = None -        self._deleted = None          if many and instance is not None and not hasattr(instance, '__iter__'):              raise ValueError('instance should be a queryset or other iterable with many=True') @@ -297,7 +299,8 @@ class BaseSerializer(WritableField):          Serialize objects -> primitives.          """          ret = self._dict_class() -        ret.fields = {} +        ret.fields = self._dict_class() +        ret.empty = obj is None          for field_name, field in self.fields.items():              field.initialize(parent=self, field_name=field_name) @@ -330,14 +333,15 @@ class BaseSerializer(WritableField):          if self.source == '*':              return self.to_native(obj) +        # Get the raw field value          try:              source = self.source or field_name              value = obj              for component in source.split('.'): -                value = get_component(value, component)                  if value is None:                      break +                value = get_component(value, component)          except ObjectDoesNotExist:              return None @@ -372,6 +376,7 @@ class BaseSerializer(WritableField):          # Set the serializer object if it exists          obj = getattr(self.parent.object, field_name) if self.parent.object else None +        obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj          if self.source == '*':              if value: @@ -385,7 +390,8 @@ class BaseSerializer(WritableField):                      'data': value,                      'context': self.context,                      'partial': self.partial, -                    'many': self.many +                    'many': self.many, +                    'allow_add_remove': self.allow_add_remove                  }                  serializer = self.__class__(**kwargs) @@ -418,8 +424,17 @@ class BaseSerializer(WritableField):          if self._errors is None:              data, files = self.init_data, self.init_files -            if self.many: -                ret = [] +            if self.many is not None: +                many = self.many +            else: +                many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) +                if many: +                    warnings.warn('Implict list/queryset serialization is deprecated. ' +                                  'Use the `many=True` flag when instantiating the serializer.', +                                  DeprecationWarning, stacklevel=3) + +            if many: +                ret = RelationsList()                  errors = []                  update = self.object is not None @@ -446,8 +461,8 @@ class BaseSerializer(WritableField):                          ret.append(self.from_native(item, None))                          errors.append(self._errors) -                    if update: -                        self._deleted = identity_to_objects.values() +                    if update and self.allow_add_remove: +                        ret._deleted = identity_to_objects.values()                      self._errors = any(errors) and errors or []                  else: @@ -490,12 +505,12 @@ class BaseSerializer(WritableField):          """          if isinstance(self.object, list):              [self.save_object(item, **kwargs) for item in self.object] + +            if self.object._deleted: +                [self.delete_object(item) for item in self.object._deleted]          else:              self.save_object(self.object, **kwargs) -        if self.allow_add_remove and self._deleted: -            [self.delete_object(item) for item in self._deleted] -          return self.object      def metadata(self): @@ -771,9 +786,12 @@ class ModelSerializer(Serializer):          cls = self.opts.model          opts = get_concrete_model(cls)._meta          exclusions = [field.name for field in opts.fields + opts.many_to_many] +          for field_name, field in self.fields.items():              field_name = field.source or field_name -            if field_name in exclusions and not field.read_only: +            if field_name in exclusions \ +                and not field.read_only \ +                and not isinstance(field, Serializer):                  exclusions.remove(field_name)          return exclusions @@ -799,6 +817,7 @@ class ModelSerializer(Serializer):          """          m2m_data = {}          related_data = {} +        nested_forward_relations = {}          meta = self.opts.model._meta          # Reverse fk or one-to-one relations @@ -818,6 +837,12 @@ class ModelSerializer(Serializer):              if field.name in attrs:                  m2m_data[field.name] = attrs.pop(field.name) +        # Nested forward relations - These need to be marked so we can save +        # them before saving the parent model instance. +        for field_name in attrs.keys(): +            if isinstance(self.fields.get(field_name, None), Serializer): +                nested_forward_relations[field_name] = attrs[field_name] +          # Update an existing instance...          if instance is not None:              for key, val in attrs.items(): @@ -833,6 +858,7 @@ class ModelSerializer(Serializer):          # at the point of save.          instance._related_data = related_data          instance._m2m_data = m2m_data +        instance._nested_forward_relations = nested_forward_relations          return instance @@ -848,6 +874,14 @@ class ModelSerializer(Serializer):          """          Save the deserialized object and return it.          """ +        if getattr(obj, '_nested_forward_relations', None): +            # Nested relationships need to be saved before we can save the +            # parent instance. +            for field_name, sub_object in obj._nested_forward_relations.items(): +                if sub_object: +                    self.save_object(sub_object) +                setattr(obj, field_name, sub_object) +          obj.save(**kwargs)          if getattr(obj, '_m2m_data', None): @@ -857,7 +891,25 @@ class ModelSerializer(Serializer):          if getattr(obj, '_related_data', None):              for accessor_name, related in obj._related_data.items(): -                setattr(obj, accessor_name, related) +                if isinstance(related, RelationsList): +                    # Nested reverse fk relationship +                    for related_item in related: +                        fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name +                        setattr(related_item, fk_field, obj) +                        self.save_object(related_item) + +                    # Delete any removed objects +                    if related._deleted: +                        [self.delete_object(item) for item in related._deleted] + +                elif isinstance(related, models.Model): +                    # Nested reverse one-one relationship +                    fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name +                    setattr(related, fk_field, obj) +                    self.save_object(related) +                else: +                    # Reverse FK or reverse one-one +                    setattr(obj, accessor_name, related)              del(obj._related_data) @@ -879,6 +931,7 @@ class HyperlinkedModelSerializer(ModelSerializer):      _options_class = HyperlinkedModelSerializerOptions      _default_view_name = '%(model_name)s-detail'      _hyperlink_field_class = HyperlinkedRelatedField +    _hyperlink_identify_field_class = HyperlinkedIdentityField      def get_default_fields(self):          fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -887,7 +940,7 @@ class HyperlinkedModelSerializer(ModelSerializer):              self.opts.view_name = self._get_default_view_name(self.opts.model)          if 'url' not in fields: -            url_field = HyperlinkedIdentityField( +            url_field = self._hyperlink_identify_field_class(                  view_name=self.opts.view_name,                  lookup_field=self.opts.lookup_field              ) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 7d25e513..8c084751 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -48,7 +48,6 @@ DEFAULTS = {      ),      'DEFAULT_THROTTLE_CLASSES': (      ), -      'DEFAULT_CONTENT_NEGOTIATION_CLASS':          'rest_framework.negotiation.DefaultContentNegotiation', @@ -68,15 +67,16 @@ DEFAULTS = {      # Pagination      'PAGINATE_BY': None,      'PAGINATE_BY_PARAM': None, - -    # View configuration -    'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', -    'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', +    'MAX_PAGINATE_BY': None,      # Authentication      'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',      'UNAUTHENTICATED_TOKEN': None, +    # View configuration +    'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', +    'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', +      # Testing      'TEST_REQUEST_RENDERER_CLASSES': (          'rest_framework.renderers.MultiPartRenderer', diff --git a/rest_framework/static/rest_framework/js/default.js b/rest_framework/static/rest_framework/js/default.js index c74829d7..bcb1964d 100644 --- a/rest_framework/static/rest_framework/js/default.js +++ b/rest_framework/static/rest_framework/js/default.js @@ -1,13 +1,56 @@ +function getCookie(c_name) +{ +    // From http://www.w3schools.com/js/js_cookies.asp +    var c_value = document.cookie; +    var c_start = c_value.indexOf(" " + c_name + "="); +    if (c_start == -1) { +        c_start = c_value.indexOf(c_name + "="); +    } +    if (c_start == -1) { +        c_value = null; +    } else { +        c_start = c_value.indexOf("=", c_start) + 1; +        var c_end = c_value.indexOf(";", c_start); +        if (c_end == -1) { +            c_end = c_value.length; +        } +        c_value = unescape(c_value.substring(c_start,c_end)); +    } +    return c_value; +} + +// JSON highlighting.  prettyPrint(); +// Bootstrap tooltips.  $('.js-tooltip').tooltip({      delay: 1000  }); +// Deal with rounded tab styling after tab clicks.  $('a[data-toggle="tab"]:first').on('shown', function (e) {      $(e.target).parents('.tabbable').addClass('first-tab-active');  });  $('a[data-toggle="tab"]:not(:first)').on('shown', function (e) {      $(e.target).parents('.tabbable').removeClass('first-tab-active');  }); -$('.form-switcher a:first').tab('show'); + +$('a[data-toggle="tab"]').click(function(){ +    document.cookie="tabstyle=" + this.name + "; path=/"; +}); + +// Store tab preference in cookies & display appropriate tab on load. +var selectedTab = null; +var selectedTabName = getCookie('tabstyle'); + +if (selectedTabName) { +    selectedTab = $('.form-switcher a[name=' + selectedTabName + ']'); +} + +if (selectedTab && selectedTab.length > 0) { +    // Display whichever tab is selected. +    selectedTab.tab('show'); +} else { +    // If no tab selected, display rightmost tab. +    $('.form-switcher a:first').tab('show'); +} diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 51f9c291..aa90e90c 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -128,17 +128,17 @@                  <div {% if post_form %}class="tabbable"{% endif %}>                      {% if post_form %}                      <ul class="nav nav-tabs form-switcher"> -                        <li><a href="#object-form" data-toggle="tab">HTML form</a></li> -                        <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li> +                        <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li> +                        <li><a name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li>                      </ul>                      {% endif %}                      <div class="well tab-content">                          {% if post_form %}                          <div class="tab-pane" id="object-form">                              {% with form=post_form %} -                            <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> +                            <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">                                  <fieldset> -                                    {% include "rest_framework/form.html" %} +                                    {{ post_form }}                                      <div class="form-actions">                                          <button class="btn btn-primary" title="Make a POST request on the {{ name }} resource">POST</button>                                      </div> @@ -167,23 +167,21 @@                  <div {% if put_form %}class="tabbable"{% endif %}>                      {% if put_form %}                      <ul class="nav nav-tabs form-switcher"> -                        <li><a href="#object-form" data-toggle="tab">HTML form</a></li> -                        <li><a href="#generic-content-form" data-toggle="tab">Raw data</a></li> +                        <li><a name='html-tab' href="#object-form" data-toggle="tab">HTML form</a></li> +                        <li><a  name='raw-tab' href="#generic-content-form" data-toggle="tab">Raw data</a></li>                      </ul>                      {% endif %}                      <div class="well tab-content">                          {% if put_form %}                          <div class="tab-pane" id="object-form"> -                            {% with form=put_form %} -                            <form action="{{ request.get_full_path }}" method="POST" {% if form.is_multipart %}enctype="multipart/form-data"{% endif %} class="form-horizontal"> +                            <form action="{{ request.get_full_path }}" method="POST" enctype="multipart/form-data" class="form-horizontal">                                  <fieldset> -                                    {% include "rest_framework/form.html" %} +                                    {{ put_form }}                                      <div class="form-actions">                                          <button class="btn btn-primary js-tooltip" name="{{ api_settings.FORM_METHOD_OVERRIDE }}" value="PUT" title="Make a PUT request on the {{ name }} resource">PUT</button>                                      </div>                                  </fieldset>                              </form> -                            {% endwith %}                          </div>                          {% endif %}                          <div {% if put_form %}class="tab-pane"{% endif %} id="generic-content-form"> diff --git a/rest_framework/test.py b/rest_framework/test.py index a18f5a29..234d10a4 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient):          """          self.handler._force_user = user          self.handler._force_token = token +        if user is None: +            self.logout()  # Also clear any possible session info if required      def request(self, **kwargs):          # Ensure that any credentials set get added to every request. diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index ebccba7d..34fbab9c 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase):          f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES)          self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES) +        result = f.from_native('') +        self.assertEqual(result, None) +  class EmailFieldTests(TestCase):      """ diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py index 487046ac..c13c38b8 100644 --- a/rest_framework/tests/test_files.py +++ b/rest_framework/tests/test_files.py @@ -7,13 +7,13 @@ import datetime  class UploadedFile(object): -    def __init__(self, file, created=None): +    def __init__(self, file=None, created=None):          self.file = file          self.created = created or datetime.datetime.now()  class UploadedFileSerializer(serializers.Serializer): -    file = serializers.FileField() +    file = serializers.FileField(required=False)      created = serializers.DateTimeField()      def restore_object(self, attrs, instance=None): @@ -47,5 +47,36 @@ class FileSerializerTests(TestCase):          now = datetime.datetime.now()          serializer = UploadedFileSerializer(data={'created': now}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.object.created, now) +        self.assertIsNone(serializer.object.file) + +    def test_remove_with_empty_string(self): +        """ +        Passing empty string as data should cause file to be removed + +        Test for: +        https://github.com/tomchristie/django-rest-framework/issues/937 +        """ +        now = datetime.datetime.now() +        file = BytesIO(six.b('stuff')) +        file.name = 'stuff.txt' +        file.size = len(file.getvalue()) + +        uploaded_file = UploadedFile(file=file, created=now) + +        serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) +        self.assertTrue(serializer.is_valid()) +        self.assertEqual(serializer.object.created, uploaded_file.created) +        self.assertIsNone(serializer.object.file) + +    def test_validation_error_with_non_file(self): +        """ +        Passing non-files should raise a validation error. +        """ +        now = datetime.datetime.now() +        errmsg = 'No file was submitted. Check the encoding type on the form.' + +        serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'})          self.assertFalse(serializer.is_valid()) -        self.assertIn('file', serializer.errors) +        self.assertEqual(serializer.errors, {'file': [errmsg]}) diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 1550880b..7a87d389 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -338,6 +338,17 @@ class TestInstanceView(TestCase):          new_obj = SlugBasedModel.objects.get(slug='test_slug')          self.assertEqual(new_obj.text, 'foobar') +    def test_patch_cannot_create_an_object(self): +        """ +        PATCH requests should not be able to create objects. +        """ +        data = {'text': 'foobar'} +        request = factory.patch('/999', data, format='json') +        with self.assertNumQueries(1): +            response = self.view(request, pk=999).render() +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) +        self.assertFalse(self.objects.filter(id=999).exists()) +  class TestOverriddenGetObject(TestCase):      """ diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index 85d4640e..4170d4b6 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView):      paginate_by_param = 'page_size' +class MaxPaginateByView(generics.ListAPIView): +    """ +    View for testing custom max_paginate_by usage +    """ +    model = BasicModel +    paginate_by = 3 +    max_paginate_by = 5 +    paginate_by_param = 'page_size' + +  class IntegrationTestPagination(TestCase):      """      Integration tests for paginated list views. @@ -313,6 +323,43 @@ class TestCustomPaginateByParam(TestCase):          self.assertEqual(response.data['results'], self.data[:5]) +class TestMaxPaginateByParam(TestCase): +    """ +    Tests for list views with max_paginate_by kwarg +    """ + +    def setUp(self): +        """ +        Create 13 BasicModel instances. +        """ +        for i in range(13): +            BasicModel(text=i).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = MaxPaginateByView.as_view() + +    def test_max_paginate_by(self): +        """ +        If max_paginate_by is set, it should limit page size for the view. +        """ +        request = factory.get('/?page_size=10') +        response = self.view(request).render() +        self.assertEqual(response.data['count'], 13) +        self.assertEqual(response.data['results'], self.data[:5]) + +    def test_max_paginate_by_without_page_size_param(self): +        """ +        If max_paginate_by is set, but client does not specifiy page_size, +        standard `paginate_by` behavior should be used. +        """ +        request = factory.get('/') +        response = self.view(request).render() +        self.assertEqual(response.data['results'], self.data[:3]) + +  ### Tests for context in pagination serializers  class CustomField(serializers.Field): diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index f6d006b3..d393b0c3 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -1,107 +1,328 @@  from __future__ import unicode_literals +from django.db import models  from django.test import TestCase  from rest_framework import serializers -from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource -class ForeignKeySourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = ForeignKeySource -        fields = ('id', 'name', 'target') -        depth = 1 +class OneToOneTarget(models.Model): +    name = models.CharField(max_length=100) -class ForeignKeyTargetSerializer(serializers.ModelSerializer): -    class Meta: -        model = ForeignKeyTarget -        fields = ('id', 'name', 'sources') -        depth = 1 +class OneToOneSource(models.Model): +    name = models.CharField(max_length=100) +    target = models.OneToOneField(OneToOneTarget, related_name='source', +                                  null=True, blank=True) -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): -    class Meta: -        model = NullableForeignKeySource -        fields = ('id', 'name', 'target') -        depth = 1 +class OneToManyTarget(models.Model): +    name = models.CharField(max_length=100) -class NullableOneToOneTargetSerializer(serializers.ModelSerializer): -    class Meta: -        model = OneToOneTarget -        fields = ('id', 'name', 'nullable_source') -        depth = 1 +class OneToManySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(OneToManyTarget, related_name='sources') -class ReverseForeignKeyTests(TestCase): +class ReverseNestedOneToOneTests(TestCase):      def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() -        new_target = ForeignKeyTarget(name='target-2') -        new_target.save() +        class OneToOneSourceSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToOneSource +                fields = ('id', 'name') + +        class OneToOneTargetSerializer(serializers.ModelSerializer): +            source = OneToOneSourceSerializer() + +            class Meta: +                model = OneToOneTarget +                fields = ('id', 'name', 'source') + +        self.Serializer = OneToOneTargetSerializer +          for idx in range(1, 4): -            source = ForeignKeySource(name='source-%d' % idx, target=target) +            target = OneToOneTarget(name='target-%d' % idx) +            target.save() +            source = OneToOneSource(name='source-%d' % idx, target=target)              source.save() -    def test_foreign_key_retrieve(self): -        queryset = ForeignKeySource.objects.all() -        serializer = ForeignKeySourceSerializer(queryset, many=True) +    def test_one_to_one_retrieve(self): +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True)          expected = [ -            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}          ]          self.assertEqual(serializer.data, expected) -    def test_reverse_foreign_key_retrieve(self): -        queryset = ForeignKeyTarget.objects.all() -        serializer = ForeignKeyTargetSerializer(queryset, many=True) +    def test_one_to_one_create(self): +        data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} +        serializer = self.Serializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-4') + +        # Ensure (target 4, target_source 4, source 4) are added, and +        # everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True)          expected = [ -            {'id': 1, 'name': 'target-1', 'sources': [ -                {'id': 1, 'name': 'source-1', 'target': 1}, -                {'id': 2, 'name': 'source-2', 'target': 1}, -                {'id': 3, 'name': 'source-3', 'target': 1}, -            ]}, -            {'id': 2, 'name': 'target-2', 'sources': [ -            ]} +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, +            {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}}          ]          self.assertEqual(serializer.data, expected) +    def test_one_to_one_create_with_invalid_data(self): +        data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) -class NestedNullableForeignKeyTests(TestCase): +    def test_one_to_one_update(self): +        data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} +        instance = OneToOneTarget.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3-updated') + +        # Ensure (target 3, target_source 3, source 3) are updated, +        # and everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +            {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +            {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} +        ] +        self.assertEqual(serializer.data, expected) + + +class ForwardNestedOneToOneTests(TestCase):      def setUp(self): -        target = ForeignKeyTarget(name='target-1') -        target.save() +        class OneToOneTargetSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToOneTarget +                fields = ('id', 'name') + +        class OneToOneSourceSerializer(serializers.ModelSerializer): +            target = OneToOneTargetSerializer() + +            class Meta: +                model = OneToOneSource +                fields = ('id', 'name', 'target') + +        self.Serializer = OneToOneSourceSerializer +          for idx in range(1, 4): -            if idx == 3: -                target = None -            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            target = OneToOneTarget(name='target-%d' % idx) +            target.save() +            source = OneToOneSource(name='source-%d' % idx, target=target)              source.save() -    def test_foreign_key_retrieve_with_null(self): -        queryset = NullableForeignKeySource.objects.all() -        serializer = NullableForeignKeySourceSerializer(queryset, many=True) +    def test_one_to_one_retrieve(self): +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create(self): +        data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} +        serializer = self.Serializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-4') + +        # Ensure (target 4, target_source 4, source 4) are added, and +        # everything else is as expected. +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, +            {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create_with_invalid_data(self): +        data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) + +    def test_one_to_one_update(self): +        data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} +        instance = OneToOneSource.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-3-updated') + +        # Ensure (target 3, target_source 3, source 3) are updated, +        # and everything else is as expected. +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True)          expected = [              {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, -            {'id': 3, 'name': 'source-3', 'target': None}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}}          ]          self.assertEqual(serializer.data, expected) +    def test_one_to_one_update_to_null(self): +        data = {'id': 3, 'name': 'source-3-updated', 'target': None} +        instance = OneToOneSource.objects.get(pk=3) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() -class NestedNullableOneToOneTests(TestCase): +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'source-3-updated') +        self.assertEqual(obj.target, None) + +        queryset = OneToOneSource.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, +            {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, +            {'id': 3, 'name': 'source-3-updated', 'target': None} +        ] +        self.assertEqual(serializer.data, expected) + +    # TODO: Nullable 1-1 tests +    # def test_one_to_one_delete(self): +    #     data = {'id': 3, 'name': 'target-3', 'target_source': None} +    #     instance = OneToOneTarget.objects.get(pk=3) +    #     serializer = self.Serializer(instance, data=data) +    #     self.assertTrue(serializer.is_valid()) +    #     serializer.save() + +    #     # Ensure (target_source 3, source 3) are deleted, +    #     # and everything else is as expected. +    #     queryset = OneToOneTarget.objects.all() +    #     serializer = self.Serializer(queryset) +    #     expected = [ +    #         {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, +    #         {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, +    #         {'id': 3, 'name': 'target-3', 'source': None} +    #     ] +    #     self.assertEqual(serializer.data, expected) + + +class ReverseNestedOneToManyTests(TestCase):      def setUp(self): -        target = OneToOneTarget(name='target-1') +        class OneToManySourceSerializer(serializers.ModelSerializer): +            class Meta: +                model = OneToManySource +                fields = ('id', 'name') + +        class OneToManyTargetSerializer(serializers.ModelSerializer): +            sources = OneToManySourceSerializer(many=True, allow_add_remove=True) + +            class Meta: +                model = OneToManyTarget +                fields = ('id', 'name', 'sources') + +        self.Serializer = OneToManyTargetSerializer + +        target = OneToManyTarget(name='target-1')          target.save() -        new_target = OneToOneTarget(name='target-2') -        new_target.save() -        source = NullableOneToOneSource(name='source-1', target=target) -        source.save() +        for idx in range(1, 4): +            source = OneToManySource(name='source-%d' % idx, target=target) +            source.save() -    def test_reverse_foreign_key_retrieve_with_null(self): -        queryset = OneToOneTarget.objects.all() -        serializer = NullableOneToOneTargetSerializer(queryset, many=True) +    def test_one_to_many_retrieve(self): +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 2, 'name': 'source-2'}, +                                                      {'id': 3, 'name': 'source-3'}]}, +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_create(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 2, 'name': 'source-2'}, +                                                         {'id': 3, 'name': 'source-3'}, +                                                         {'id': 4, 'name': 'source-4'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-1') + +        # Ensure source 4 is added, and everything else is as +        # expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True)          expected = [ -            {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, -            {'id': 2, 'name': 'target-2', 'nullable_source': None}, +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 2, 'name': 'source-2'}, +                                                      {'id': 3, 'name': 'source-3'}, +                                                      {'id': 4, 'name': 'source-4'}]} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_create_with_invalid_data(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 2, 'name': 'source-2'}, +                                                         {'id': 3, 'name': 'source-3'}, +                                                         {'id': 4}]} +        serializer = self.Serializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) + +    def test_one_to_many_update(self): +        data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, +                                                                 {'id': 2, 'name': 'source-2'}, +                                                                 {'id': 3, 'name': 'source-3'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-1-updated') + +        # Ensure (target 1, source 1) are updated, +        # and everything else is as expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, +                                                              {'id': 2, 'name': 'source-2'}, +                                                              {'id': 3, 'name': 'source-3'}]} + +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_many_delete(self): +        data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                         {'id': 3, 'name': 'source-3'}]} +        instance = OneToManyTarget.objects.get(pk=1) +        serializer = self.Serializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() + +        # Ensure source 2 is deleted, and everything else is as +        # expected. +        queryset = OneToManyTarget.objects.all() +        serializer = self.Serializer(queryset, many=True) +        expected = [ +            {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, +                                                      {'id': 3, 'name': 'source-3'}]} +          ]          self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index c3597e38..3f456fef 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase):          self.urls = self.router.urls      def test_urls_can_have_trailing_slash_removed(self): -        expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] +        expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']          for idx in range(len(expected)):              self.assertEqual(expected[idx], self.urls[idx].regex.pattern) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 49d45fc2..48b8956b 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -17,8 +17,18 @@ def view(request):      }) +@api_view(['GET', 'POST']) +def session_view(request): +    active_session = request.session.get('active_session', False) +    request.session['active_session'] = True +    return Response({ +        'active_session': active_session +    }) + +  urlpatterns = patterns('',      url(r'^view/$', view), +    url(r'^session-view/$', session_view),  ) @@ -46,6 +56,26 @@ class TestAPITestClient(TestCase):          response = self.client.get('/view/')          self.assertEqual(response.data['user'], 'example') +    def test_force_authenticate_with_sessions(self): +        """ +        Setting `.force_authenticate()` forcibly authenticates each request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        self.client.force_authenticate(user) + +        # First request does not yet have an active session +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], False) + +        # Subsequant requests have an active session +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], True) + +        # Force authenticating as `None` should also logout the user session. +        self.client.force_authenticate(None) +        response = self.client.get('/session-view/') +        self.assertEqual(response.data['active_session'], False) +      def test_csrf_exempt_by_default(self):          """          By default, the test client is CSRF exempt. diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 65b45593..a946d837 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -2,7 +2,7 @@  Provides various throttling policies.  """  from __future__ import unicode_literals -from django.core.cache import cache +from django.core.cache import cache as default_cache  from django.core.exceptions import ImproperlyConfigured  from rest_framework.settings import api_settings  import time @@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle):      Previous request information used for throttling is stored in the cache.      """ +    cache = default_cache      timer = time.time      cache_format = 'throtte_%(scope)s_%(ident)s'      scope = None @@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle):          if self.key is None:              return True -        self.history = cache.get(self.key, []) +        self.history = self.cache.get(self.key, [])          self.now = self.timer()          # Drop any requests from the history which have now passed the @@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle):          into the cache.          """          self.history.insert(0, self.now) -        cache.set(self.key, self.history, self.duration) +        self.cache.set(self.key, self.history, self.duration)          return True      def throttle_failure(self): @@ -151,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle):          if request.user.is_authenticated():              return None  # Only throttle unauthenticated requests. -        ident = request.META.get('REMOTE_ADDR', None) +        ident = request.META.get('HTTP_X_FORWARDED_FOR') +        if ident is None: +            ident = request.META.get('REMOTE_ADDR')          return self.cache_format % {              'scope': self.scope, diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 0384faba..e6690d17 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -8,8 +8,11 @@ def get_breadcrumbs(url):      tuple of (name, url).      """ +    from rest_framework.settings import api_settings      from rest_framework.views import APIView +    view_name_func = api_settings.VIEW_NAME_FUNCTION +      def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen):          """          Add tuples of (name, url) to the breadcrumbs list, @@ -28,8 +31,8 @@ def get_breadcrumbs(url):                  # Don't list the same view twice in a row.                  # Probably an optional trailing slash.                  if not seen or seen[-1] != view: -                    instance = view.cls() -                    name = instance.get_view_name() +                    suffix = getattr(view, 'suffix', None) +                    name = view_name_func(cls, suffix)                      breadcrumbs_list.insert(0, (name, prefix + url))                      seen.append(view) diff --git a/rest_framework/views.py b/rest_framework/views.py index 727a9f95..4cff0422 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -15,8 +15,14 @@ from rest_framework.settings import api_settings  from rest_framework.utils import formatting -def get_view_name(cls, suffix=None): -    name = cls.__name__ +def get_view_name(view_cls, suffix=None): +    """ +    Given a view class, return a textual name to represent the view. +    This name is used in the browsable API, and in OPTIONS responses. + +    This function is the default for the `VIEW_NAME_FUNCTION` setting. +    """ +    name = view_cls.__name__      name = formatting.remove_trailing_string(name, 'View')      name = formatting.remove_trailing_string(name, 'ViewSet')      name = formatting.camelcase_to_spaces(name) @@ -25,17 +31,56 @@ def get_view_name(cls, suffix=None):      return name -def get_view_description(cls, html=False): -    description = cls.__doc__ or '' +def get_view_description(view_cls, html=False): +    """ +    Given a view class, return a textual description to represent the view. +    This name is used in the browsable API, and in OPTIONS responses. + +    This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting. +    """ +    description = view_cls.__doc__ or ''      description = formatting.dedent(smart_text(description))      if html:          return formatting.markup_description(description)      return description +def exception_handler(exc): +    """ +    Returns the response that should be used for any given exception. + +    By default we handle the REST framework `APIException`, and also +    Django's builtin `Http404` and `PermissionDenied` exceptions. + +    Any unhandled exceptions may return `None`, which will cause a 500 error +    to be raised. +    """ +    if isinstance(exc, exceptions.APIException): +        headers = {} +        if getattr(exc, 'auth_header', None): +            headers['WWW-Authenticate'] = exc.auth_header +        if getattr(exc, 'wait', None): +            headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + +        return Response({'detail': exc.detail}, +                        status=exc.status_code, +                        headers=headers) + +    elif isinstance(exc, Http404): +        return Response({'detail': 'Not found'}, +                        status=status.HTTP_404_NOT_FOUND) + +    elif isinstance(exc, PermissionDenied): +        return Response({'detail': 'Permission denied'}, +                        status=status.HTTP_403_FORBIDDEN) + +    # Note: Unhandled exceptions will raise a 500 error. +    return None + +  class APIView(View): -    settings = api_settings +    # The following policies may be set at either globally, or per-view.      renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES      parser_classes = api_settings.DEFAULT_PARSER_CLASSES      authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES @@ -43,6 +88,9 @@ class APIView(View):      permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES      content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS +    # Allow dependancy injection of other settings to make testing easier. +    settings = api_settings +      @classmethod      def as_view(cls, **initkwargs):          """ @@ -133,7 +181,7 @@ class APIView(View):          Return the view name, as used in OPTIONS responses and in the          browsable API.          """ -        func = api_settings.VIEW_NAME_FUNCTION +        func = self.settings.VIEW_NAME_FUNCTION          return func(self.__class__, getattr(self, 'suffix', None))      def get_view_description(self, html=False): @@ -141,7 +189,7 @@ class APIView(View):          Return some descriptive text for the view, as used in OPTIONS responses          and in the browsable API.          """ -        func = api_settings.VIEW_DESCRIPTION_FUNCTION +        func = self.settings.VIEW_DESCRIPTION_FUNCTION          return func(self.__class__, html)      # API policy instantiation methods @@ -303,33 +351,23 @@ class APIView(View):          Handle any exception that occurs, by returning an appropriate response,          or re-raising the error.          """ -        if isinstance(exc, exceptions.Throttled) and exc.wait is not None: -            # Throttle wait header -            self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait -          if isinstance(exc, (exceptions.NotAuthenticated,                              exceptions.AuthenticationFailed)):              # WWW-Authenticate header for 401 responses, else coerce to 403              auth_header = self.get_authenticate_header(self.request)              if auth_header: -                self.headers['WWW-Authenticate'] = auth_header +                exc.auth_header = auth_header              else:                  exc.status_code = status.HTTP_403_FORBIDDEN -        if isinstance(exc, exceptions.APIException): -            return Response({'detail': exc.detail}, -                            status=exc.status_code, -                            exception=True) -        elif isinstance(exc, Http404): -            return Response({'detail': 'Not found'}, -                            status=status.HTTP_404_NOT_FOUND, -                            exception=True) -        elif isinstance(exc, PermissionDenied): -            return Response({'detail': 'Permission denied'}, -                            status=status.HTTP_403_FORBIDDEN, -                            exception=True) -        raise +        response = exception_handler(exc) + +        if response is None: +            raise + +        response.exception = True +        return response      # Note: session based authentication is explicitly CSRF validated,      # all other authentication is CSRF exempt. | 
