aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/fields.py
diff options
context:
space:
mode:
authorTom Christie2013-01-15 17:53:24 +0000
committerTom Christie2013-01-15 17:53:24 +0000
commit71e55cc4f6300959398f7aef4a8d91b6a6a2af57 (patch)
tree68c2080034263d897741da33cbc5e09746006257 /rest_framework/fields.py
parent52847a215d4e8de88e81d9ae79ce8bee9a36a9a2 (diff)
parente1076cfb49b6293aa837cf7bdb4c11988892c598 (diff)
downloaddjango-rest-framework-71e55cc4f6300959398f7aef4a8d91b6a6a2af57.tar.bz2
Merge with latest master
Diffstat (limited to 'rest_framework/fields.py')
-rw-r--r--rest_framework/fields.py455
1 files changed, 235 insertions, 220 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 6ed37823..998911e1 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,16 +1,18 @@
import copy
import datetime
import inspect
+import re
import warnings
+from io import BytesIO
+
from django.core import validators
-from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve
+from django.core.exceptions import ValidationError
from django.conf import settings
+from django import forms
from django.forms import widgets
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
-from rest_framework.reverse import reverse
from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone
@@ -26,9 +28,12 @@ def is_simple_callable(obj):
class Field(object):
+ read_only = True
creation_counter = 0
empty = ''
type_name = None
+ _use_files = None
+ form_field_class = forms.CharField
def __init__(self, source=None):
self.parent = None
@@ -38,18 +43,20 @@ class Field(object):
self.source = source
- def initialize(self, parent):
+ def initialize(self, parent, field_name):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
- model_field - The model field this field corrosponds to, if one exists.
+ model_field - The model field this field corresponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
+ if self.root.partial:
+ self.required = False
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
@@ -88,6 +95,8 @@ class Field(object):
return value
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
+ elif isinstance(value, dict):
+ return dict(map(self.to_native, (k, v)) for k, v in value.items())
return smart_unicode(value)
def attributes(self):
@@ -111,17 +120,17 @@ class WritableField(Field):
widget = widgets.TextInput
default = None
- def __init__(self, source=None, readonly=False, required=None,
+ def __init__(self, source=None, read_only=False, required=None,
validators=[], error_messages=None, widget=None,
- default=None):
+ default=None, blank=None):
super(WritableField, self).__init__(source=source)
- self.readonly = readonly
+ self.read_only = read_only
if required is None:
- self.required = not(readonly)
+ self.required = not(read_only)
else:
- assert not readonly, "Cannot set required=True and readonly=True"
+ assert not (read_only and required), "Cannot set required=True and read_only=True"
self.required = required
messages = {}
@@ -131,7 +140,8 @@ class WritableField(Field):
self.error_messages = messages
self.validators = self.default_validators + validators
- self.default = default or self.default
+ self.default = default if default is not None else self.default
+ self.blank = blank
# Widgets are ony used for HTML forms.
widget = widget or self.widget
@@ -161,18 +171,23 @@ class WritableField(Field):
if errors:
raise ValidationError(errors)
- def field_from_native(self, data, field_name, into):
+ def field_from_native(self, data, files, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
"""
- if self.readonly:
+ if self.read_only:
return
try:
- native = data[field_name]
+ if self._use_files:
+ files = files or {}
+ native = files[field_name]
+ else:
+ native = data[field_name]
except KeyError:
- if self.default is not None:
+ if self.default is not None and not self.root.partial:
+ # Note: partial updates shouldn't set defaults
native = self.default
else:
if self.required:
@@ -197,21 +212,32 @@ class WritableField(Field):
class ModelField(WritableField):
"""
- A generic field that can be used against an arbirtrary model field.
+ A generic field that can be used against an arbitrary model field.
"""
def __init__(self, *args, **kwargs):
try:
self.model_field = kwargs.pop('model_field')
except:
raise ValueError("ModelField requires 'model_field' kwarg")
+
+ self.min_length = kwargs.pop('min_length',
+ getattr(self.model_field, 'min_length', None))
+ self.max_length = kwargs.pop('max_length',
+ getattr(self.model_field, 'max_length', None))
+
super(ModelField, self).__init__(*args, **kwargs)
+ if self.min_length is not None:
+ self.validators.append(validators.MinLengthValidator(self.min_length))
+ if self.max_length is not None:
+ self.validators.append(validators.MaxLengthValidator(self.max_length))
+
def from_native(self, value):
- try:
- rel = self.model_field.rel
- except:
+ rel = getattr(self.model_field, "rel", None)
+ if rel is not None:
+ return rel.to._meta.get_field(rel.field_name).to_python(value)
+ else:
return self.model_field.to_python(value)
- return rel.to._meta.get_field(rel.field_name).to_python(value)
def field_to_native(self, obj, field_name):
value = self.model_field._get_val_from_obj(obj)
@@ -224,200 +250,12 @@ class ModelField(WritableField):
"type": self.model_field.get_internal_type()
}
-##### Relational fields #####
-
-
-class RelatedField(WritableField):
- """
- Base class for related model fields.
- """
- def __init__(self, *args, **kwargs):
- self.queryset = kwargs.pop('queryset', None)
- super(RelatedField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- value = getattr(obj, self.source or field_name)
- return self.to_native(value)
-
- def field_from_native(self, data, field_name, into):
- if self.readonly:
- return
-
- value = data.get(field_name)
- into[(self.source or field_name) + '_id'] = self.from_native(value)
-
-
-class ManyRelatedMixin(object):
- """
- Mixin to convert a related field to a many related field.
- """
- def field_to_native(self, obj, field_name):
- value = getattr(obj, self.source or field_name)
- return [self.to_native(item) for item in value.all()]
-
- def field_from_native(self, data, field_name, into):
- if self.readonly:
- return
-
- try:
- # Form data
- value = data.getlist(self.source or field_name)
- except:
- # Non-form data
- value = data.get(self.source or field_name)
- else:
- if value == ['']:
- value = []
- into[field_name] = [self.from_native(item) for item in value]
-
-
-class ManyRelatedField(ManyRelatedMixin, RelatedField):
- """
- Base class for related model managers.
- """
- pass
-
-
-### PrimaryKey relationships
-
-class PrimaryKeyRelatedField(RelatedField):
- """
- Serializes a related field or related object to a pk value.
- """
-
- def to_native(self, pk):
- return pk
-
- def field_to_native(self, obj, field_name):
- try:
- # Prefer obj.serializable_value for performance reasons
- pk = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedObject (reverse relationship)
- obj = getattr(obj, self.source or field_name)
- return self.to_native(obj.pk)
- # Forward relationship
- return self.to_native(pk)
-
-
-class ManyPrimaryKeyRelatedField(ManyRelatedField):
- """
- Serializes a to-many related field or related manager to a pk value.
- """
- def to_native(self, pk):
- return pk
-
- def field_to_native(self, obj, field_name):
- try:
- # Prefer obj.serializable_value for performance reasons
- queryset = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedManager (reverse relationship)
- queryset = getattr(obj, self.source or field_name)
- return [self.to_native(item.pk) for item in queryset.all()]
- # Forward relationship
- return [self.to_native(item.pk) for item in queryset.all()]
-
-
-### Hyperlinked relationships
-
-class HyperlinkedRelatedField(RelatedField):
- pk_url_kwarg = 'pk'
- slug_url_kwarg = 'slug'
- slug_field = 'slug'
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except:
- raise ValueError("Hyperlinked field requires 'view_name' kwarg")
- super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
-
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- kwargs = {self.pk_url_kwarg: obj.pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request)
- except:
- pass
-
- slug = getattr(obj, self.slug_field, None)
-
- if not slug:
- raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
-
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(self.view_name, kwargs=kwargs, request=request)
- except:
- pass
-
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
- try:
- return reverse(self.view_name, kwargs=kwargs, request=request)
- except:
- pass
-
- raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
-
- def from_native(self, value):
- # Convert URL -> model instance pk
- # TODO: Use values_list
- try:
- match = resolve(value)
- except:
- raise ValidationError('Invalid hyperlink - No URL match')
-
- if match.url_name != self.view_name:
- raise ValidationError('Invalid hyperlink - Incorrect URL match')
-
- pk = match.kwargs.get(self.pk_url_kwarg, None)
- slug = match.kwargs.get(self.slug_url_kwarg, None)
-
- # Try explicit primary key.
- if pk is not None:
- return pk
- # Next, try looking up by slug.
- elif slug is not None:
- slug_field = self.get_slug_field()
- queryset = self.queryset.filter(**{slug_field: slug})
- # If none of those are defined, it's an error.
- else:
- raise ValidationError('Invalid hyperlink')
-
- try:
- obj = queryset.get()
- except ObjectDoesNotExist:
- raise ValidationError('Invalid hyperlink - object does not exist.')
- return obj.pk
-
-
-class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
- pass
-
-
-class HyperlinkedIdentityField(Field):
- """
- A field that represents the model's identity using a hyperlink.
- """
- def __init__(self, *args, **kwargs):
- # TODO: Make this mandatory, and have the HyperlinkedModelSerializer
- # set it on-the-fly
- self.view_name = kwargs.pop('view_name', None)
- super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- request = self.context.get('request', None)
- view_name = self.view_name or self.parent.opts.view_name
- view_kwargs = {'pk': obj.pk}
- return reverse(view_name, kwargs=view_kwargs, request=request)
-
##### Typed Fields #####
class BooleanField(WritableField):
type_name = 'BooleanField'
+ form_field_class = forms.BooleanField
widget = widgets.CheckboxInput
default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."),
@@ -430,15 +268,16 @@ class BooleanField(WritableField):
default = False
def from_native(self, value):
- if value in ('t', 'True', '1'):
+ if value in ('true', 't', 'True', '1'):
return True
- if value in ('f', 'False', '0'):
+ if value in ('false', 'f', 'False', '0'):
return False
return bool(value)
class CharField(WritableField):
type_name = 'CharField'
+ form_field_class = forms.CharField
def __init__(self, max_length=None, min_length=None, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length
@@ -448,14 +287,42 @@ class CharField(WritableField):
if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length))
+ def validate(self, value):
+ """
+ Validates that the value is supplied (if required).
+ """
+ # if empty string and allow blank
+ if self.blank and not value:
+ return
+ else:
+ super(CharField, self).validate(value)
+
def from_native(self, value):
if isinstance(value, basestring) or value is None:
return value
return smart_unicode(value)
+class URLField(CharField):
+ type_name = 'URLField'
+
+ def __init__(self, **kwargs):
+ kwargs['max_length'] = kwargs.get('max_length', 200)
+ kwargs['validators'] = [validators.URLValidator()]
+ super(URLField, self).__init__(**kwargs)
+
+
+class SlugField(CharField):
+ type_name = 'SlugField'
+
+ def __init__(self, *args, **kwargs):
+ kwargs['max_length'] = kwargs.get('max_length', 50)
+ super(SlugField, self).__init__(*args, **kwargs)
+
+
class ChoiceField(WritableField):
type_name = 'ChoiceField'
+ form_field_class = forms.ChoiceField
widget = widgets.Select
default_error_messages = {
'invalid_choice': _('Select a valid choice. %(value)s is not one of the available choices.'),
@@ -495,13 +362,14 @@ class ChoiceField(WritableField):
if value == smart_unicode(k2):
return True
else:
- if value == smart_unicode(k):
+ if value == smart_unicode(k) or value == k:
return True
return False
class EmailField(CharField):
type_name = 'EmailField'
+ form_field_class = forms.EmailField
default_error_messages = {
'invalid': _('Enter a valid e-mail address.'),
@@ -509,7 +377,10 @@ class EmailField(CharField):
default_validators = [validators.validate_email]
def from_native(self, value):
- return super(EmailField, self).from_native(value).strip()
+ ret = super(EmailField, self).from_native(value)
+ if ret is None:
+ return None
+ return ret.strip()
def __deepcopy__(self, memo):
result = copy.copy(self)
@@ -519,8 +390,39 @@ class EmailField(CharField):
return result
+class RegexField(CharField):
+ type_name = 'RegexField'
+ form_field_class = forms.RegexField
+
+ def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
+ super(RegexField, self).__init__(max_length, min_length, *args, **kwargs)
+ self.regex = regex
+
+ def _get_regex(self):
+ return self._regex
+
+ def _set_regex(self, regex):
+ if isinstance(regex, basestring):
+ regex = re.compile(regex)
+ self._regex = regex
+ if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
+ self.validators.remove(self._regex_validator)
+ self._regex_validator = validators.RegexValidator(regex=regex)
+ self.validators.append(self._regex_validator)
+
+ regex = property(_get_regex, _set_regex)
+
+ def __deepcopy__(self, memo):
+ result = copy.copy(self)
+ memo[id(self)] = result
+ result.validators = self.validators[:]
+ return result
+
+
class DateField(WritableField):
type_name = 'DateField'
+ widget = widgets.DateInput
+ form_field_class = forms.DateField
default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be "
@@ -531,8 +433,9 @@ class DateField(WritableField):
empty = None
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
if isinstance(value, datetime.datetime):
if timezone and settings.USE_TZ and timezone.is_aware(value):
# Convert aware datetimes to the default time zone
@@ -557,6 +460,8 @@ class DateField(WritableField):
class DateTimeField(WritableField):
type_name = 'DateTimeField'
+ widget = widgets.DateTimeInput
+ form_field_class = forms.DateTimeField
default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in "
@@ -570,8 +475,9 @@ class DateTimeField(WritableField):
empty = None
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
@@ -610,6 +516,7 @@ class DateTimeField(WritableField):
class IntegerField(WritableField):
type_name = 'IntegerField'
+ form_field_class = forms.IntegerField
default_error_messages = {
'invalid': _('Enter a whole number.'),
@@ -629,6 +536,7 @@ class IntegerField(WritableField):
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
+
try:
value = int(str(value))
except (ValueError, TypeError):
@@ -638,16 +546,123 @@ class IntegerField(WritableField):
class FloatField(WritableField):
type_name = 'FloatField'
+ form_field_class = forms.FloatField
default_error_messages = {
'invalid': _("'%s' value must be a float."),
}
def from_native(self, value):
- if value is None:
- return value
+ if value in validators.EMPTY_VALUES:
+ return None
+
try:
return float(value)
except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
+
+
+class FileField(WritableField):
+ _use_files = True
+ type_name = 'FileField'
+ form_field_class = forms.FileField
+ widget = widgets.FileInput
+
+ default_error_messages = {
+ 'invalid': _("No file was submitted. Check the encoding type on the form."),
+ 'missing': _("No file was submitted."),
+ 'empty': _("The submitted file is empty."),
+ 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
+ 'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.max_length = kwargs.pop('max_length', None)
+ self.allow_empty_file = kwargs.pop('allow_empty_file', False)
+ super(FileField, self).__init__(*args, **kwargs)
+
+ def from_native(self, data):
+ if data in validators.EMPTY_VALUES:
+ return None
+
+ # UploadedFile objects should have name and size attributes.
+ try:
+ file_name = data.name
+ file_size = data.size
+ except AttributeError:
+ raise ValidationError(self.error_messages['invalid'])
+
+ if self.max_length is not None and len(file_name) > self.max_length:
+ error_values = {'max': self.max_length, 'length': len(file_name)}
+ raise ValidationError(self.error_messages['max_length'] % error_values)
+ if not file_name:
+ raise ValidationError(self.error_messages['invalid'])
+ if not self.allow_empty_file and not file_size:
+ raise ValidationError(self.error_messages['empty'])
+
+ return data
+
+ def to_native(self, value):
+ return value.name
+
+
+class ImageField(FileField):
+ _use_files = True
+ form_field_class = forms.ImageField
+
+ default_error_messages = {
+ 'invalid_image': _("Upload a valid image. The file you uploaded was either not an image or a corrupted image."),
+ }
+
+ def from_native(self, data):
+ """
+ Checks that the file-upload field data contains a valid image (GIF, JPG,
+ PNG, possibly others -- whatever the Python Imaging Library supports).
+ """
+ f = super(ImageField, self).from_native(data)
+ if f is None:
+ return None
+
+ from compat import Image
+ assert Image is not None, 'PIL must be installed for ImageField support'
+
+ # We need to get a file object for PIL. We might have a path or we might
+ # have to read the data into memory.
+ if hasattr(data, 'temporary_file_path'):
+ file = data.temporary_file_path()
+ else:
+ if hasattr(data, 'read'):
+ file = BytesIO(data.read())
+ else:
+ file = BytesIO(data['content'])
+
+ try:
+ # load() could spot a truncated JPEG, but it loads the entire
+ # image in memory, which is a DoS vector. See #3848 and #18520.
+ # verify() must be called immediately after the constructor.
+ Image.open(file).verify()
+ except ImportError:
+ # Under PyPy, it is possible to import PIL. However, the underlying
+ # _imaging C module isn't available, so an ImportError will be
+ # raised. Catch and re-raise.
+ raise
+ except Exception: # Python Imaging Library doesn't recognize it as an image
+ raise ValidationError(self.error_messages['invalid_image'])
+ if hasattr(f, 'seek') and callable(f.seek):
+ f.seek(0)
+ return f
+
+
+class SerializerMethodField(Field):
+ """
+ A field that gets its value by calling a method on the serializer it's attached to.
+ """
+
+ def __init__(self, method_name):
+ self.method_name = method_name
+ super(SerializerMethodField, self).__init__()
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(self.parent, self.method_name)(obj)
+ return self.to_native(value)