diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/parsers.py | 20 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 28 | ||||
| -rw-r--r-- | rest_framework/utils/formatting.py | 6 |
3 files changed, 44 insertions, 10 deletions
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index aa4fd3f1..c287908d 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -11,7 +11,7 @@ from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter from django.utils import six -from rest_framework.compat import etree, yaml, force_text +from rest_framework.compat import etree, yaml, force_text, urlparse from rest_framework.exceptions import ParseError from rest_framework import renderers import json @@ -290,6 +290,22 @@ class FileUploadParser(BaseParser): try: meta = parser_context['request'].META disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8')) - return force_text(disposition[1]['filename']) + filename_parm = disposition[1] + if 'filename*' in filename_parm: + return self.get_encoded_filename(filename_parm) + return force_text(filename_parm['filename']) except (AttributeError, KeyError): pass + + def get_encoded_filename(self, filename_parm): + """ + Handle encoded filenames per RFC6266. See also: + http://tools.ietf.org/html/rfc2231#section-4 + """ + encoded_filename = force_text(filename_parm['filename*']) + try: + charset, lang, filename = encoded_filename.split('\'', 2) + filename = urlparse.unquote(filename) + except (ValueError, LookupError): + filename = force_text(filename_parm['filename']) + return filename diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f2..b3db3582 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -625,6 +625,20 @@ class ModelSerializerOptions(SerializerOptions): self.write_only_fields = getattr(meta, 'write_only_fields', ()) +def _get_class_mapping(mapping, obj): + """ + Takes a dictionary with classes as keys, and an object. + Traverses the object's inheritance hierarchy in method + resolution order, and returns the first matching value + from the dictionary or None. + + """ + return next( + (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), + None + ) + + class ModelSerializer(Serializer): """ A serializer that deals with model instances and querysets. @@ -899,15 +913,17 @@ class ModelSerializer(Serializer): models.URLField: ['max_length'], } - if model_field.__class__ in attribute_dict: - attributes = attribute_dict[model_field.__class__] + attributes = _get_class_mapping(attribute_dict, model_field) + if attributes: for attribute in attributes: kwargs.update({attribute: getattr(model_field, attribute)}) - try: - return self.field_mapping[model_field.__class__](**kwargs) - except KeyError: - return ModelField(model_field=model_field, **kwargs) + serializer_field_class = _get_class_mapping( + self.field_mapping, model_field) + + if serializer_field_class: + return serializer_field_class(**kwargs) + return ModelField(model_field=model_field, **kwargs) def get_validation_exclusions(self, instance=None): """ diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 6d53aed1..470af51b 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -2,11 +2,12 @@ Utility functions to return a formatted name and description for a given view. """ from __future__ import unicode_literals +import re from django.utils.html import escape from django.utils.safestring import mark_safe -from rest_framework.compat import apply_markdown -import re + +from rest_framework.compat import apply_markdown, force_text def remove_trailing_string(content, trailing): @@ -28,6 +29,7 @@ def dedent(content): as it fails to dedent multiline docstrings that include unindented text on the initial line. """ + content = force_text(content) whitespace_counts = [len(line) - len(line.lstrip(' ')) for line in content.splitlines()[1:] if line.lstrip()] |
