aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarko Tibold2012-11-13 23:26:17 +0100
committerMarko Tibold2012-11-13 23:26:17 +0100
commit5443dd5f3c5f75cd1524eb26c6d5b53df3594f9b (patch)
tree90b6338de44f12090a21315b82f008c2137f9759
parent44e9749e36d31f811db7dc0998c7b8d1e35a784d (diff)
downloaddjango-rest-framework-5443dd5f3c5f75cd1524eb26c6d5b53df3594f9b.tar.bz2
Added a FileField and an ImageField (copied from django.forms.fields).
Adjusted generics, mixins and serializers to take a `files` arg where applicable.
-rw-r--r--rest_framework/fields.py91
-rw-r--r--rest_framework/generics.py3
-rw-r--r--rest_framework/mixins.py4
-rw-r--r--rest_framework/serializers.py21
4 files changed, 108 insertions, 11 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 4c206426..9cd84c0d 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -904,3 +904,94 @@ class FloatField(WritableField):
except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
+
+
+class FileField(WritableField):
+ type_name = 'FileField'
+
+ default_error_messages = {
+ 'invalid': _("No file was submitted. Check the encoding type on the form."),
+ 'missing': _("No file was submitted."),
+ 'empty': _("The submitted file is empty."),
+ 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
+ 'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.max_length = kwargs.pop('max_length', None)
+ self.allow_empty_file = kwargs.pop('allow_empty_file', False)
+ super(FileField, self).__init__(*args, **kwargs)
+
+ def from_native(self, data):
+ if data in validators.EMPTY_VALUES:
+ return None
+
+ # UploadedFile objects should have name and size attributes.
+ try:
+ file_name = data.name
+ file_size = data.size
+ except AttributeError:
+ raise ValidationError(self.error_messages['invalid'])
+
+ if self.max_length is not None and len(file_name) > self.max_length:
+ error_values = {'max': self.max_length, 'length': len(file_name)}
+ raise ValidationError(self.error_messages['max_length'] % error_values)
+ if not file_name:
+ raise ValidationError(self.error_messages['invalid'])
+ if not self.allow_empty_file and not file_size:
+ raise ValidationError(self.error_messages['empty'])
+
+ return data
+
+ def to_native(self, value):
+ """
+ No need to return anything, the file can be accessed form its url.
+ """
+ return
+
+
+class ImageField(FileField):
+ default_error_messages = {
+ 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
+ }
+
+ def from_native(self, data):
+ """
+ Checks that the file-upload field data contains a valid image (GIF, JPG,
+ PNG, possibly others -- whatever the Python Imaging Library supports).
+ """
+ f = super(ImageField, self).from_native(data)
+ if f is None:
+ return None
+
+ # Try to import PIL in either of the two ways it can end up installed.
+ try:
+ from PIL import Image
+ except ImportError:
+ import Image
+
+ # We need to get a file object for PIL. We might have a path or we might
+ # have to read the data into memory.
+ if hasattr(data, 'temporary_file_path'):
+ file = data.temporary_file_path()
+ else:
+ if hasattr(data, 'read'):
+ file = BytesIO(data.read())
+ else:
+ file = BytesIO(data['content'])
+
+ try:
+ # load() could spot a truncated JPEG, but it loads the entire
+ # image in memory, which is a DoS vector. See #3848 and #18520.
+ # verify() must be called immediately after the constructor.
+ Image.open(file).verify()
+ except ImportError:
+ # Under PyPy, it is possible to import PIL. However, the underlying
+ # _imaging C module isn't available, so an ImportError will be
+ # raised. Catch and re-raise.
+ raise
+ except Exception: # Python Imaging Library doesn't recognize it as an image
+ raise ValidationError(self.error_messages['invalid_image'])
+ if hasattr(f, 'seek') and callable(f.seek):
+ f.seek(0)
+ return f
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index ebd06e45..d47c39cd 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -44,11 +44,10 @@ class GenericAPIView(views.APIView):
return serializer_class
def get_serializer(self, instance=None, data=None, files=None):
- # TODO: add support for files
# TODO: add support for seperate serializer/deserializer
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(instance, data=data, context=context)
+ return serializer_class(instance, data=data, files=files, context=context)
class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView):
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index c3625a88..991f4c50 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -15,7 +15,7 @@ class CreateModelMixin(object):
Should be mixed in with any `BaseView`.
"""
def create(self, request, *args, **kwargs):
- serializer = self.get_serializer(data=request.DATA)
+ serializer = self.get_serializer(data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
@@ -80,7 +80,7 @@ class UpdateModelMixin(object):
self.object = None
success_status = status.HTTP_201_CREATED
- serializer = self.get_serializer(self.object, data=request.DATA)
+ serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 46d4765e..a46432a9 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -91,7 +91,7 @@ class BaseSerializer(Field):
_options_class = SerializerOptions
_dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
- def __init__(self, instance=None, data=None, context=None, **kwargs):
+ def __init__(self, instance=None, data=None, files=None, context=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.fields = copy.deepcopy(self.base_fields)
@@ -101,9 +101,11 @@ class BaseSerializer(Field):
self.context = context or {}
self.init_data = data
+ self.init_files = files
self.object = instance
self._data = None
+ self._files = None
self._errors = None
#####
@@ -187,7 +189,7 @@ class BaseSerializer(Field):
ret.fields[key] = field
return ret
- def restore_fields(self, data):
+ def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
@@ -196,7 +198,10 @@ class BaseSerializer(Field):
reverted_data = {}
for field_name, field in fields.items():
try:
- field.field_from_native(data, field_name, reverted_data)
+ if isinstance(field, (FileField, ImageField)):
+ field.field_from_native(files, field_name, reverted_data)
+ else:
+ field.field_from_native(data, field_name, reverted_data)
except ValidationError as err:
self._errors[field_name] = list(err.messages)
@@ -250,7 +255,7 @@ class BaseSerializer(Field):
return [self.convert_object(item) for item in obj]
return self.convert_object(obj)
- def from_native(self, data):
+ def from_native(self, data, files):
"""
Deserialize primatives -> objects.
"""
@@ -259,8 +264,8 @@ class BaseSerializer(Field):
return (self.from_native(item) for item in data)
self._errors = {}
- if data is not None:
- attrs = self.restore_fields(data)
+ if data is not None or files is not None:
+ attrs = self.restore_fields(data, files)
attrs = self.perform_validation(attrs)
else:
self._errors['non_field_errors'] = ['No input provided']
@@ -288,7 +293,7 @@ class BaseSerializer(Field):
setting self.object if no errors occurred.
"""
if self._errors is None:
- obj = self.from_native(self.init_data)
+ obj = self.from_native(self.init_data, self.init_files)
if not self._errors:
self.object = obj
return self._errors
@@ -440,6 +445,8 @@ class ModelSerializer(Serializer):
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
+ models.FileField: FileField,
+ models.ImageField: ImageField,
}
try:
return field_mapping[model_field.__class__](**kwargs)