aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/fields.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/fields.py')
-rw-r--r--rest_framework/fields.py794
1 files changed, 594 insertions, 200 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 0c78b3fb..ca9c479f 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,18 +1,25 @@
from django.conf import settings
-from django.core import validators
-from django.core.exceptions import ValidationError
-from django.utils import timezone
+from django.core.exceptions import ObjectDoesNotExist
+from django.core.exceptions import ValidationError as DjangoValidationError
+from django.core.validators import RegexValidator
+from django.forms import ImageField as DjangoImageField
+from django.utils import six, timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from rest_framework import ISO_8601
-from rest_framework.compat import smart_text
+from rest_framework.compat import (
+ smart_text, EmailValidator, MinValueValidator, MaxValueValidator,
+ MinLengthValidator, MaxLengthValidator, URLValidator, OrderedDict
+)
+from rest_framework.exceptions import ValidationError
from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime
+import copy
import datetime
import decimal
import inspect
-import warnings
+import re
class empty:
@@ -49,13 +56,20 @@ def get_attribute(instance, attrs):
Also accepts either attribute lookup on objects or dictionary lookups.
"""
for attr in attrs:
+ if instance is None:
+ # Break out early if we get `None` at any point in a nested lookup.
+ return None
try:
instance = getattr(instance, attr)
+ except ObjectDoesNotExist:
+ return None
except AttributeError as exc:
try:
return instance[attr]
- except (KeyError, TypeError):
+ except (KeyError, TypeError, AttributeError):
raise exc
+ if is_simple_callable(instance):
+ instance = instance()
return instance
@@ -80,14 +94,48 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value
+class CreateOnlyDefault:
+ """
+ This class may be used to provide default values that are only used
+ for create operations, but that do not return any value for update
+ operations.
+ """
+ def __init__(self, default):
+ self.default = default
+
+ def set_context(self, serializer_field):
+ self.is_update = serializer_field.parent.instance is not None
+
+ def __call__(self):
+ if self.is_update:
+ raise SkipField()
+ if callable(self.default):
+ return self.default()
+ return self.default
+
+ def __repr__(self):
+ return '%s(%s)' % (self.__class__.__name__, repr(self.default))
+
+
+class CurrentUserDefault:
+ def set_context(self, serializer_field):
+ self.user = serializer_field.context['request'].user
+
+ def __call__(self):
+ return self.user
+
+ def __repr__(self):
+ return '%s()' % self.__class__.__name__
+
+
class SkipField(Exception):
pass
NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
-NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
+USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'
MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.'
@@ -98,14 +146,17 @@ class Field(object):
_creation_counter = 0
default_error_messages = {
- 'required': _('This field is required.')
+ 'required': _('This field is required.'),
+ 'null': _('This field may not be null.')
}
default_validators = []
+ default_empty_html = empty
+ initial = None
def __init__(self, read_only=False, write_only=False,
- required=None, default=empty, initial=None, source=None,
+ required=None, default=empty, initial=empty, source=None,
label=None, help_text=None, style=None,
- error_messages=None, validators=[]):
+ error_messages=None, validators=None, allow_null=False):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -116,19 +167,29 @@ class Field(object):
# Some combinations of keyword arguments do not make sense.
assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
assert not (read_only and required), NOT_READ_ONLY_REQUIRED
- assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
+ assert not (read_only and self.__class__ == Field), USE_READONLYFIELD
self.read_only = read_only
self.write_only = write_only
self.required = required
self.default = default
self.source = source
- self.initial = initial
+ self.initial = self.initial if (initial is empty) else initial
self.label = label
self.help_text = help_text
self.style = {} if style is None else style
- self.validators = validators or self.default_validators[:]
+ self.allow_null = allow_null
+
+ if allow_null and self.default_empty_html is empty:
+ self.default_empty_html = None
+
+ if validators is not None:
+ self.validators = validators[:]
+
+ # These are set up by `.bind()` when the field is added to a serializer.
+ self.field_name = None
+ self.parent = None
# Collect default error message from self and parent classes
messages = {}
@@ -137,26 +198,26 @@ class Field(object):
messages.update(error_messages or {})
self.error_messages = messages
- def __new__(cls, *args, **kwargs):
+ def bind(self, field_name, parent):
"""
- When a field is instantiated, we store the arguments that were used,
- so that we can present a helpful representation of the object.
+ Initializes the field name and parent for the field instance.
+ Called when a field is added to the parent serializer instance.
"""
- instance = super(Field, cls).__new__(cls)
- instance._args = args
- instance._kwargs = kwargs
- return instance
- def bind(self, field_name, parent, root):
- """
- Setup the context for the field instance.
- """
+ # In order to enforce a consistent style, we error if a redundant
+ # 'source' argument has been used. For example:
+ # my_field = serializer.CharField(source='my_field')
+ assert self.source != field_name, (
+ "It is redundant to specify `source='%s'` on field '%s' in "
+ "serializer '%s', because it is the same as the field name. "
+ "Remove the `source` keyword argument." %
+ (field_name, self.__class__.__name__, parent.__class__.__name__)
+ )
+
self.field_name = field_name
self.parent = parent
- self.root = root
- self.context = parent.context
- # `self.label` should deafult to being based on the field name.
+ # `self.label` should default to being based on the field name.
if self.label is None:
self.label = field_name.replace('_', ' ').capitalize()
@@ -171,24 +232,48 @@ class Field(object):
else:
self.source_attrs = self.source.split('.')
+ # .validators is a lazily loaded property, that gets its default
+ # value from `get_validators`.
+ @property
+ def validators(self):
+ if not hasattr(self, '_validators'):
+ self._validators = self.get_validators()
+ return self._validators
+
+ @validators.setter
+ def validators(self, validators):
+ self._validators = validators
+
+ def get_validators(self):
+ return self.default_validators[:]
+
def get_initial(self):
"""
- Return a value to use when the field is being returned as a primative
+ Return a value to use when the field is being returned as a primitive
value, without any object instance.
"""
return self.initial
def get_value(self, dictionary):
"""
- Given the *incoming* primative data, return the value for this field
+ Given the *incoming* primitive data, return the value for this field
that should be validated and transformed to a native value.
"""
+ if html.is_html_input(dictionary):
+ # HTML forms will represent empty fields as '', and cannot
+ # represent None or False values directly.
+ if self.field_name not in dictionary:
+ if getattr(self.root, 'partial', False):
+ return empty
+ return self.default_empty_html
+ ret = dictionary[self.field_name]
+ return self.default_empty_html if (ret == '') else ret
return dictionary.get(self.field_name, empty)
def get_attribute(self, instance):
"""
- Given the *outgoing* object instance, return the value for this field
- that should be returned as a primative value.
+ Given the *outgoing* object instance, return the primitive value
+ that should be used for this field.
"""
return get_attribute(instance, self.source_attrs)
@@ -203,47 +288,74 @@ class Field(object):
"""
if self.default is empty:
raise SkipField()
+ if callable(self.default):
+ if hasattr(self.default, 'set_context'):
+ self.default.set_context(self)
+ return self.default()
return self.default
def run_validation(self, data=empty):
"""
Validate a simple representation and return the internal value.
- The provided data may be `empty` if no representation was included.
- May return `empty` if the field should not be included in the
+ The provided data may be `empty` if no representation was included
+ in the input.
+
+ May raise `SkipField` if the field should not be included in the
validated data.
"""
+ if self.read_only:
+ return self.get_default()
+
if data is empty:
+ if getattr(self.root, 'partial', False):
+ raise SkipField()
if self.required:
self.fail('required')
return self.get_default()
+ if data is None:
+ if not self.allow_null:
+ self.fail('null')
+ return None
+
value = self.to_internal_value(data)
self.run_validators(value)
return value
def run_validators(self, value):
- if value in (None, '', [], (), {}):
- return
-
+ """
+ Test the given value against all the validators on the field,
+ and either raise a `ValidationError` or simply return.
+ """
errors = []
for validator in self.validators:
+ if hasattr(validator, 'set_context'):
+ validator.set_context(self)
+
try:
validator(value)
except ValidationError as exc:
+ # If the validation error contains a mapping of fields to
+ # errors then simply raise it immediately rather than
+ # attempting to accumulate a list of errors.
+ if isinstance(exc.detail, dict):
+ raise
+ errors.extend(exc.detail)
+ except DjangoValidationError as exc:
errors.extend(exc.messages)
if errors:
raise ValidationError(errors)
def to_internal_value(self, data):
"""
- Transform the *incoming* primative data into a native value.
+ Transform the *incoming* primitive data into a native value.
"""
raise NotImplementedError('to_internal_value() must be implemented.')
def to_representation(self, value):
"""
- Transform the *outgoing* native value into primative data.
+ Transform the *outgoing* native value into primitive data.
"""
raise NotImplementedError('to_representation() must be implemented.')
@@ -257,9 +369,59 @@ class Field(object):
class_name = self.__class__.__name__
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
- raise ValidationError(msg.format(**kwargs))
+ message_string = msg.format(**kwargs)
+ raise ValidationError(message_string)
+
+ @property
+ def root(self):
+ """
+ Returns the top-level serializer for this field.
+ """
+ root = self
+ while root.parent is not None:
+ root = root.parent
+ return root
+
+ @property
+ def context(self):
+ """
+ Returns the context as passed to the root serializer on initialization.
+ """
+ return getattr(self.root, '_context', {})
+
+ def __new__(cls, *args, **kwargs):
+ """
+ When a field is instantiated, we store the arguments that were used,
+ so that we can present a helpful representation of the object.
+ """
+ instance = super(Field, cls).__new__(cls)
+ instance._args = args
+ instance._kwargs = kwargs
+ return instance
+
+ def __deepcopy__(self, memo):
+ """
+ When cloning fields we instantiate using the arguments it was
+ originally created with, rather than copying the complete state.
+ """
+ args = copy.deepcopy(self._args)
+ kwargs = dict(self._kwargs)
+ # Bit ugly, but we need to special case 'validators' as Django's
+ # RegexValidator does not support deepcopy.
+ # We treat validator callables as immutable objects.
+ # See https://github.com/tomchristie/django-rest-framework/issues/1954
+ validators = kwargs.pop('validators', None)
+ kwargs = copy.deepcopy(kwargs)
+ if validators is not None:
+ kwargs['validators'] = validators
+ return self.__class__(*args, **kwargs)
def __repr__(self):
+ """
+ Fields are represented using their initial calling arguments.
+ This allows us to create descriptive representations for serializer
+ instances that show all the declared fields on the serializer.
+ """
return representation.field_repr(self)
@@ -269,25 +431,55 @@ class BooleanField(Field):
default_error_messages = {
'invalid': _('`{input}` is not a valid boolean.')
}
+ default_empty_html = False
+ initial = False
TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
- def get_value(self, dictionary):
- if html.is_html_input(dictionary):
- # HTML forms do not send a `False` value on an empty checkbox,
- # so we override the default empty value to be False.
- return dictionary.get(self.field_name, False)
- return dictionary.get(self.field_name, empty)
+ def __init__(self, **kwargs):
+ assert 'allow_null' not in kwargs, '`allow_null` is not a valid option. Use `NullBooleanField` instead.'
+ super(BooleanField, self).__init__(**kwargs)
+
+ def to_internal_value(self, data):
+ if data in self.TRUE_VALUES:
+ return True
+ elif data in self.FALSE_VALUES:
+ return False
+ self.fail('invalid', input=data)
+
+ def to_representation(self, value):
+ if value in self.TRUE_VALUES:
+ return True
+ elif value in self.FALSE_VALUES:
+ return False
+ return bool(value)
+
+
+class NullBooleanField(Field):
+ default_error_messages = {
+ 'invalid': _('`{input}` is not a valid boolean.')
+ }
+ initial = None
+ TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
+ FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))
+ NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None))
+
+ def __init__(self, **kwargs):
+ assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.'
+ kwargs['allow_null'] = True
+ super(NullBooleanField, self).__init__(**kwargs)
def to_internal_value(self, data):
if data in self.TRUE_VALUES:
return True
elif data in self.FALSE_VALUES:
return False
+ elif data in self.NULL_VALUES:
+ return None
self.fail('invalid', input=data)
def to_representation(self, value):
- if value is None:
+ if value in self.NULL_VALUES:
return None
if value in self.TRUE_VALUES:
return True
@@ -300,136 +492,174 @@ class BooleanField(Field):
class CharField(Field):
default_error_messages = {
- 'blank': _('This field may not be blank.')
+ 'blank': _('This field may not be blank.'),
+ 'max_length': _('Ensure this field has no more than {max_length} characters.'),
+ 'min_length': _('Ensure this field has no more than {min_length} characters.')
}
+ initial = ''
+ coerce_blank_to_null = False
+ default_empty_html = ''
def __init__(self, **kwargs):
self.allow_blank = kwargs.pop('allow_blank', False)
- self.max_length = kwargs.pop('max_length', None)
- self.min_length = kwargs.pop('min_length', None)
+ max_length = kwargs.pop('max_length', None)
+ min_length = kwargs.pop('min_length', None)
super(CharField, self).__init__(**kwargs)
+ if max_length is not None:
+ message = self.error_messages['max_length'].format(max_length=max_length)
+ self.validators.append(MaxLengthValidator(max_length, message=message))
+ if min_length is not None:
+ message = self.error_messages['min_length'].format(min_length=min_length)
+ self.validators.append(MinLengthValidator(min_length, message=message))
+
+ def run_validation(self, data=empty):
+ # Test for the empty string here so that it does not get validated,
+ # and so that subclasses do not need to handle it explicitly
+ # inside the `to_internal_value()` method.
+ if data == '':
+ if not self.allow_blank:
+ self.fail('blank')
+ return ''
+ return super(CharField, self).run_validation(data)
def to_internal_value(self, data):
- if data == '' and not self.allow_blank:
- self.fail('blank')
- if data is None:
- return None
- return str(data)
+ return six.text_type(data)
def to_representation(self, value):
- if value is None:
- return None
- return str(value)
+ return six.text_type(value)
class EmailField(CharField):
default_error_messages = {
'invalid': _('Enter a valid email address.')
}
- default_validators = [validators.validate_email]
+
+ def __init__(self, **kwargs):
+ super(EmailField, self).__init__(**kwargs)
+ validator = EmailValidator(message=self.error_messages['invalid'])
+ self.validators.append(validator)
def to_internal_value(self, data):
- if data == '' and not self.allow_blank:
- self.fail('blank')
- if data is None:
- return None
- return str(data).strip()
+ return six.text_type(data).strip()
def to_representation(self, value):
- if value is None:
- return None
- return str(value).strip()
+ return six.text_type(value).strip()
class RegexField(CharField):
+ default_error_messages = {
+ 'invalid': _('This value does not match the required pattern.')
+ }
+
def __init__(self, regex, **kwargs):
- kwargs['validators'] = (
- [validators.RegexValidator(regex)] +
- kwargs.get('validators', [])
- )
super(RegexField, self).__init__(**kwargs)
+ validator = RegexValidator(regex, message=self.error_messages['invalid'])
+ self.validators.append(validator)
class SlugField(CharField):
default_error_messages = {
'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
}
- default_validators = [validators.validate_slug]
+
+ def __init__(self, **kwargs):
+ super(SlugField, self).__init__(**kwargs)
+ slug_regex = re.compile(r'^[-a-zA-Z0-9_]+$')
+ validator = RegexValidator(slug_regex, message=self.error_messages['invalid'])
+ self.validators.append(validator)
class URLField(CharField):
default_error_messages = {
'invalid': _("Enter a valid URL.")
}
- default_validators = [validators.URLValidator()]
+
+ def __init__(self, **kwargs):
+ super(URLField, self).__init__(**kwargs)
+ validator = URLValidator(message=self.error_messages['invalid'])
+ self.validators.append(validator)
# Number types...
class IntegerField(Field):
default_error_messages = {
- 'invalid': _('A valid integer is required.')
+ 'invalid': _('A valid integer is required.'),
+ 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
+ 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
+ 'max_string_length': _('String value too large')
}
+ MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
def __init__(self, **kwargs):
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(IntegerField, self).__init__(**kwargs)
if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
+ message = self.error_messages['max_value'].format(max_value=max_value)
+ self.validators.append(MaxValueValidator(max_value, message=message))
if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
+ message = self.error_messages['min_value'].format(min_value=min_value)
+ self.validators.append(MinValueValidator(min_value, message=message))
def to_internal_value(self, data):
+ if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH:
+ self.fail('max_string_length')
+
try:
- data = int(str(data))
+ data = int(data)
except (ValueError, TypeError):
self.fail('invalid')
return data
def to_representation(self, value):
- if value is None:
- return None
return int(value)
class FloatField(Field):
default_error_messages = {
- 'invalid': _("'%s' value must be a float."),
+ 'invalid': _("A valid number is required."),
+ 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
+ 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
+ 'max_string_length': _('String value too large')
}
+ MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
def __init__(self, **kwargs):
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(FloatField, self).__init__(**kwargs)
if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
+ message = self.error_messages['max_value'].format(max_value=max_value)
+ self.validators.append(MaxValueValidator(max_value, message=message))
if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
+ message = self.error_messages['min_value'].format(min_value=min_value)
+ self.validators.append(MinValueValidator(min_value, message=message))
- def to_internal_value(self, value):
- if value is None:
- return None
- return float(value)
+ def to_internal_value(self, data):
+ if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH:
+ self.fail('max_string_length')
- def to_representation(self, value):
- if value is None:
- return None
try:
- return float(value)
+ return float(data)
except (TypeError, ValueError):
- self.fail('invalid', value=value)
+ self.fail('invalid')
+
+ def to_representation(self, value):
+ return float(value)
class DecimalField(Field):
default_error_messages = {
- 'invalid': _('Enter a number.'),
+ 'invalid': _('A valid number is required.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
- 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
+ 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'),
+ 'max_string_length': _('String value too large')
}
+ MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING
@@ -439,23 +669,25 @@ class DecimalField(Field):
self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string
super(DecimalField, self).__init__(**kwargs)
if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
+ message = self.error_messages['max_value'].format(max_value=max_value)
+ self.validators.append(MaxValueValidator(max_value, message=message))
if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
+ message = self.error_messages['min_value'].format(min_value=min_value)
+ self.validators.append(MinValueValidator(min_value, message=message))
- def to_internal_value(self, value):
+ def to_internal_value(self, data):
"""
Validates that the input is a decimal number. Returns a Decimal
instance. Returns None for empty values. Ensures that there are no more
than max_digits in the number, and no more than decimal_places digits
after the decimal point.
"""
- if value in (None, ''):
- return None
+ data = smart_text(data).strip()
+ if len(data) > self.MAX_STRING_LENGTH:
+ self.fail('max_string_length')
- value = smart_text(value).strip()
try:
- value = decimal.Decimal(value)
+ value = decimal.Decimal(data)
except decimal.DecimalException:
self.fail('invalid')
@@ -485,125 +717,116 @@ class DecimalField(Field):
if self.decimal_places is not None and decimals > self.decimal_places:
self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
- self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
+ self.fail('max_whole_digits', max_whole_digits=self.max_digits - self.decimal_places)
return value
def to_representation(self, value):
- if isinstance(value, decimal.Decimal):
- context = decimal.getcontext().copy()
- context.prec = self.max_digits
- quantized = value.quantize(
- decimal.Decimal('.1') ** self.decimal_places,
- context=context
- )
- if not self.coerce_to_string:
- return quantized
- return '{0:f}'.format(quantized)
-
+ if not isinstance(value, decimal.Decimal):
+ value = decimal.Decimal(six.text_type(value).strip())
+
+ context = decimal.getcontext().copy()
+ context.prec = self.max_digits
+ quantized = value.quantize(
+ decimal.Decimal('.1') ** self.decimal_places,
+ context=context
+ )
if not self.coerce_to_string:
- return value
- return '%.*f' % (self.max_decimal_places, value)
+ return quantized
+ return '{0:f}'.format(quantized)
# Date & time fields...
-class DateField(Field):
+class DateTimeField(Field):
default_error_messages = {
- 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
+ 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
+ 'date': _('Expected a datetime but got a date.'),
}
- format = api_settings.DATE_FORMAT
- input_formats = api_settings.DATE_INPUT_FORMATS
+ format = api_settings.DATETIME_FORMAT
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
+ default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None
- def __init__(self, format=None, input_formats=None, *args, **kwargs):
- self.format = format if format is not None else self.format
+ def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs):
+ self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
- super(DateField, self).__init__(*args, **kwargs)
+ self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone
+ super(DateTimeField, self).__init__(*args, **kwargs)
+
+ def enforce_timezone(self, value):
+ """
+ When `self.default_timezone` is `None`, always return naive datetimes.
+ When `self.default_timezone` is not `None`, always return aware datetimes.
+ """
+ if (self.default_timezone is not None) and not timezone.is_aware(value):
+ return timezone.make_aware(value, self.default_timezone)
+ elif (self.default_timezone is None) and timezone.is_aware(value):
+ return timezone.make_naive(value, timezone.UTC())
+ return value
def to_internal_value(self, value):
- if value in (None, ''):
- return None
+ if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
+ self.fail('date')
if isinstance(value, datetime.datetime):
- if timezone and settings.USE_TZ and timezone.is_aware(value):
- # Convert aware datetimes to the default time zone
- # before casting them to dates (#17742).
- default_timezone = timezone.get_default_timezone()
- value = timezone.make_naive(value, default_timezone)
- return value.date()
-
- if isinstance(value, datetime.date):
- return value
+ return self.enforce_timezone(value)
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
- parsed = parse_date(value)
+ parsed = parse_datetime(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
- return parsed
+ return self.enforce_timezone(parsed)
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
- return parsed.date()
+ return self.enforce_timezone(parsed)
- humanized_format = humanize_datetime.date_formats(self.input_formats)
+ humanized_format = humanize_datetime.datetime_formats(self.input_formats)
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
- if value is None or self.format is None:
+ if self.format is None:
return value
- if isinstance(value, datetime.datetime):
- value = value.date()
-
if self.format.lower() == ISO_8601:
- return value.isoformat()
+ value = value.isoformat()
+ if value.endswith('+00:00'):
+ value = value[:-6] + 'Z'
+ return value
return value.strftime(self.format)
-class DateTimeField(Field):
+class DateField(Field):
default_error_messages = {
- 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'),
+ 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'),
+ 'datetime': _('Expected a date but got a datetime.'),
}
- format = api_settings.DATETIME_FORMAT
- input_formats = api_settings.DATETIME_INPUT_FORMATS
+ format = api_settings.DATE_FORMAT
+ input_formats = api_settings.DATE_INPUT_FORMATS
- def __init__(self, format=None, input_formats=None, *args, **kwargs):
- self.format = format if format is not None else self.format
+ def __init__(self, format=empty, input_formats=None, *args, **kwargs):
+ self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
- super(DateTimeField, self).__init__(*args, **kwargs)
+ super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
- if value in (None, ''):
- return None
-
if isinstance(value, datetime.datetime):
- return value
+ self.fail('datetime')
if isinstance(value, datetime.date):
- value = datetime.datetime(value.year, value.month, value.day)
- if settings.USE_TZ:
- # For backwards compatibility, interpret naive datetimes in
- # local time. This won't work during DST change, but we can't
- # do much about it, so we let the exceptions percolate up the
- # call stack.
- warnings.warn("DateTimeField received a naive datetime (%s)"
- " while time zone support is active." % value,
- RuntimeWarning)
- default_timezone = timezone.get_default_timezone()
- value = timezone.make_aware(value, default_timezone)
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
- parsed = parse_datetime(value)
+ parsed = parse_date(value)
except (ValueError, TypeError):
pass
else:
@@ -615,20 +838,26 @@ class DateTimeField(Field):
except (ValueError, TypeError):
pass
else:
- return parsed
+ return parsed.date()
- humanized_format = humanize_datetime.datetime_formats(self.input_formats)
+ humanized_format = humanize_datetime.date_formats(self.input_formats)
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
- if value is None or self.format is None:
+ if self.format is None:
return value
+ # Applying a `DateField` to a datetime value is almost always
+ # not a sensible thing to do, as it means naively dropping
+ # any explicit or implicit timezone info.
+ assert not isinstance(value, datetime.datetime), (
+ 'Expected a `date`, but got a `datetime`. Refusing to coerce, '
+ 'as this may mean losing timezone information. Use a custom '
+ 'read-only field and deal with timezone issues explicitly.'
+ )
+
if self.format.lower() == ISO_8601:
- ret = value.isoformat()
- if ret.endswith('+00:00'):
- ret = ret[:-6] + 'Z'
- return ret
+ return value.isoformat()
return value.strftime(self.format)
@@ -639,15 +868,12 @@ class TimeField(Field):
format = api_settings.TIME_FORMAT
input_formats = api_settings.TIME_INPUT_FORMATS
- def __init__(self, format=None, input_formats=None, *args, **kwargs):
- self.format = format if format is not None else self.format
+ def __init__(self, format=empty, input_formats=None, *args, **kwargs):
+ self.format = format if format is not empty else self.format
self.input_formats = input_formats if input_formats is not None else self.input_formats
super(TimeField, self).__init__(*args, **kwargs)
- def from_native(self, value):
- if value in (None, ''):
- return None
-
+ def to_internal_value(self, value):
if isinstance(value, datetime.time):
return value
@@ -672,11 +898,17 @@ class TimeField(Field):
self.fail('invalid', format=humanized_format)
def to_representation(self, value):
- if value is None or self.format is None:
+ if self.format is None:
return value
- if isinstance(value, datetime.datetime):
- value = value.time()
+ # Applying a `TimeField` to a datetime value is almost always
+ # not a sensible thing to do, as it means naively dropping
+ # any explicit or implicit timezone info.
+ assert not isinstance(value, datetime.datetime), (
+ 'Expected a `time`, but got a `datetime`. Refusing to coerce, '
+ 'as this may mean losing timezone information. Use a custom '
+ 'read-only field and deal with timezone issues explicitly.'
+ )
if self.format.lower() == ISO_8601:
return value.isoformat()
@@ -699,58 +931,171 @@ class ChoiceField(Field):
for item in choices
]
if all(pairs):
- self.choices = dict([(key, display_value) for key, display_value in choices])
+ self.choices = OrderedDict([(key, display_value) for key, display_value in choices])
else:
- self.choices = dict([(item, item) for item in choices])
+ self.choices = OrderedDict([(item, item) for item in choices])
# Map the string representation of choices to the underlying value.
# Allows us to deal with eg. integer choices while supporting either
# integer or string input, but still get the correct datatype out.
self.choice_strings_to_values = dict([
- (str(key), key) for key in self.choices.keys()
+ (six.text_type(key), key) for key in self.choices.keys()
])
super(ChoiceField, self).__init__(**kwargs)
def to_internal_value(self, data):
try:
- return self.choice_strings_to_values[str(data)]
+ return self.choice_strings_to_values[six.text_type(data)]
except KeyError:
self.fail('invalid_choice', input=data)
def to_representation(self, value):
- return value
+ if value in ('', None):
+ return value
+ return self.choice_strings_to_values[six.text_type(value)]
class MultipleChoiceField(ChoiceField):
default_error_messages = {
'invalid_choice': _('`{input}` is not a valid choice.'),
- 'not_a_list': _('Expected a list of items but got type `{input_type}`')
+ 'not_a_list': _('Expected a list of items but got type `{input_type}`.')
}
+ default_empty_html = []
+
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # lists in HTML forms.
+ if html.is_html_input(dictionary):
+ return dictionary.getlist(self.field_name)
+ return dictionary.get(self.field_name, empty)
def to_internal_value(self, data):
- if not hasattr(data, '__iter__'):
+ if isinstance(data, type('')) or not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__)
+
return set([
super(MultipleChoiceField, self).to_internal_value(item)
for item in data
])
def to_representation(self, value):
- return value
+ return set([
+ self.choice_strings_to_values[six.text_type(item)] for item in value
+ ])
# File types...
class FileField(Field):
- pass # TODO
+ default_error_messages = {
+ 'required': _("No file was submitted."),
+ 'invalid': _("The submitted data was not a file. Check the encoding type on the form."),
+ 'no_name': _("No filename could be determined."),
+ 'empty': _("The submitted file is empty."),
+ 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'),
+ }
+ use_url = api_settings.UPLOADED_FILES_USE_URL
+
+ def __init__(self, *args, **kwargs):
+ self.max_length = kwargs.pop('max_length', None)
+ self.allow_empty_file = kwargs.pop('allow_empty_file', False)
+ self.use_url = kwargs.pop('use_url', self.use_url)
+ super(FileField, self).__init__(*args, **kwargs)
+
+ def to_internal_value(self, data):
+ try:
+ # `UploadedFile` objects should have name and size attributes.
+ file_name = data.name
+ file_size = data.size
+ except AttributeError:
+ self.fail('invalid')
+
+ if not file_name:
+ self.fail('no_name')
+ if not self.allow_empty_file and not file_size:
+ self.fail('empty')
+ if self.max_length and len(file_name) > self.max_length:
+ self.fail('max_length', max_length=self.max_length, length=len(file_name))
+
+ return data
+
+ def to_representation(self, value):
+ if self.use_url:
+ if not value:
+ return None
+ url = value.url
+ request = self.context.get('request', None)
+ if request is not None:
+ return request.build_absolute_uri(url)
+ return url
+ return value.name
+
+
+class ImageField(FileField):
+ default_error_messages = {
+ 'invalid_image': _(
+ 'Upload a valid image. The file you uploaded was either not an '
+ 'image or a corrupted image.'
+ ),
+ }
+
+ def __init__(self, *args, **kwargs):
+ self._DjangoImageField = kwargs.pop('_DjangoImageField', DjangoImageField)
+ super(ImageField, self).__init__(*args, **kwargs)
+
+ def to_internal_value(self, data):
+ # Image validation is a bit grungy, so we'll just outright
+ # defer to Django's implementation so we don't need to
+ # consider it, or treat PIL as a test dependency.
+ file_object = super(ImageField, self).to_internal_value(data)
+ django_field = self._DjangoImageField()
+ django_field.error_messages = self.error_messages
+ django_field.to_python(file_object)
+ return file_object
+
+
+# Composite field types...
+
+class ListField(Field):
+ child = None
+ initial = []
+ default_error_messages = {
+ 'not_a_list': _('Expected a list of items but got type `{input_type}`')
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.child = kwargs.pop('child', copy.deepcopy(self.child))
+ assert self.child is not None, '`child` is a required argument.'
+ assert not inspect.isclass(self.child), '`child` has not been instantiated.'
+ super(ListField, self).__init__(*args, **kwargs)
+ self.child.bind(field_name='', parent=self)
+
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # lists in HTML forms.
+ if html.is_html_input(dictionary):
+ return html.parse_html_list(dictionary, prefix=self.field_name)
+ return dictionary.get(self.field_name, empty)
+ def to_internal_value(self, data):
+ """
+ List of dicts of native values <- List of dicts of primitive datatypes.
+ """
+ if html.is_html_input(data):
+ data = html.parse_html_list(data)
+ if isinstance(data, type('')) or not hasattr(data, '__iter__'):
+ self.fail('not_a_list', input_type=type(data).__name__)
+ return [self.child.run_validation(item) for item in data]
-class ImageField(Field):
- pass # TODO
+ def to_representation(self, data):
+ """
+ List of object instances -> List of dicts of primitive datatypes.
+ """
+ return [self.child.to_representation(item) for item in data]
-# Advanced field types...
+# Miscellaneous field types...
class ReadOnlyField(Field):
"""
@@ -770,11 +1115,31 @@ class ReadOnlyField(Field):
super(ReadOnlyField, self).__init__(**kwargs)
def to_representation(self, value):
- if is_simple_callable(value):
- return value()
return value
+class HiddenField(Field):
+ """
+ A hidden field does not take input from the user, or present any output,
+ but it does populate a field in `validated_data`, based on its default
+ value. This is particularly useful when we have a `unique_for_date`
+ constraint on a pair of fields, as we need some way to include the date in
+ the validated data.
+ """
+ def __init__(self, **kwargs):
+ assert 'default' in kwargs, 'default is a required argument.'
+ kwargs['write_only'] = True
+ super(HiddenField, self).__init__(**kwargs)
+
+ def get_value(self, dictionary):
+ # We always use the default value for `HiddenField`.
+ # User input is never provided or accepted.
+ return empty
+
+ def to_internal_value(self, data):
+ return data
+
+
class SerializerMethodField(Field):
"""
A read-only field that get its representation from calling a method on the
@@ -790,17 +1155,32 @@ class SerializerMethodField(Field):
def get_extra_info(self, obj):
return ... # Calculate some data to return.
"""
- def __init__(self, method_attr=None, **kwargs):
- self.method_attr = method_attr
+ def __init__(self, method_name=None, **kwargs):
+ self.method_name = method_name
kwargs['source'] = '*'
kwargs['read_only'] = True
super(SerializerMethodField, self).__init__(**kwargs)
+ def bind(self, field_name, parent):
+ # In order to enforce a consistent style, we error if a redundant
+ # 'method_name' argument has been used. For example:
+ # my_field = serializer.CharField(source='my_field')
+ default_method_name = 'get_{field_name}'.format(field_name=field_name)
+ assert self.method_name != default_method_name, (
+ "It is redundant to specify `%s` on SerializerMethodField '%s' in "
+ "serializer '%s', because it is the same as the default method name. "
+ "Remove the `method_name` argument." %
+ (self.method_name, field_name, parent.__class__.__name__)
+ )
+
+ # The method name should default to `get_{field_name}`.
+ if self.method_name is None:
+ self.method_name = default_method_name
+
+ super(SerializerMethodField, self).bind(field_name, parent)
+
def to_representation(self, value):
- method_attr = self.method_attr
- if method_attr is None:
- method_attr = 'get_{field_name}'.format(field_name=self.field_name)
- method = getattr(self.parent, method_attr)
+ method = getattr(self.parent, self.method_name)
return method(value)
@@ -811,10 +1191,19 @@ class ModelField(Field):
This is used by `ModelSerializer` when dealing with custom model fields,
that do not have a serializer field to be mapped to.
"""
+ default_error_messages = {
+ 'max_length': _('Ensure this field has no more than {max_length} characters.'),
+ }
+
def __init__(self, model_field, **kwargs):
self.model_field = model_field
- kwargs['source'] = '*'
+ # The `max_length` option is supported by Django's base `Field` class,
+ # so we'd better support it here.
+ max_length = kwargs.pop('max_length', None)
super(ModelField, self).__init__(**kwargs)
+ if max_length is not None:
+ message = self.error_messages['max_length'].format(max_length=max_length)
+ self.validators.append(MaxLengthValidator(max_length, message=message))
def to_internal_value(self, data):
rel = getattr(self.model_field, 'rel', None)
@@ -822,6 +1211,11 @@ class ModelField(Field):
return rel.to._meta.get_field(rel.field_name).to_python(data)
return self.model_field.to_python(data)
+ def get_attribute(self, obj):
+ # We pass the object instance onto `to_representation`,
+ # not just the field attribute.
+ return obj
+
def to_representation(self, obj):
value = self.model_field._get_val_from_obj(obj)
if is_protected_type(value):