diff options
Diffstat (limited to 'rest_framework/fields.py')
| -rw-r--r-- | rest_framework/fields.py | 1713 |
1 files changed, 976 insertions, 737 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c0253f86..c40dc3fb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,34 +1,38 @@ -""" -Serializer fields perform validation on incoming data. - -They are very similar to Django's form fields. -""" from __future__ import unicode_literals - -import copy -import datetime -import inspect -import re -import warnings -from decimal import Decimal, DecimalException -from django import forms -from django.core import validators -from django.core.exceptions import ValidationError from django.conf import settings -from django.db.models.fields import BLANK_CHOICE_DASH -from django.http import QueryDict -from django.forms import widgets +from django.core.exceptions import ObjectDoesNotExist +from django.core.exceptions import ValidationError as DjangoValidationError +from django.core.validators import RegexValidator +from django.forms import ImageField as DjangoImageField from django.utils import six, timezone -from django.utils.encoding import is_protected_type -from django.utils.translation import ugettext_lazy as _ -from django.utils.datastructures import SortedDict from django.utils.dateparse import parse_date, parse_datetime, parse_time +from django.utils.encoding import is_protected_type, smart_text +from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 from rest_framework.compat import ( - BytesIO, smart_text, - force_text, is_non_str_iterable + EmailValidator, MinValueValidator, MaxValueValidator, + MinLengthValidator, MaxLengthValidator, URLValidator, OrderedDict, + unicode_repr, unicode_to_repr ) +from rest_framework.exceptions import ValidationError from rest_framework.settings import api_settings +from rest_framework.utils import html, representation, humanize_datetime +import collections +import copy +import datetime +import decimal +import inspect +import re + + +class empty: + """ + This class is used to represent no data being provided for a given input + or output value. + + It is required because `None` may be a valid input or output value. + """ + pass def is_simple_callable(obj): @@ -47,683 +51,835 @@ def is_simple_callable(obj): return len_args <= len_defaults -def get_component(obj, attr_name): +def get_attribute(instance, attrs): """ - Given an object, and an attribute name, - return that attribute on the object. + Similar to Python's built in `getattr(instance, attr)`, + but takes a list of nested attributes, instead of a single attribute. + + Also accepts either attribute lookup on objects or dictionary lookups. """ - if isinstance(obj, dict): - val = obj.get(attr_name) - else: - val = getattr(obj, attr_name) - - if is_simple_callable(val): - return val() - return val - - -def readable_datetime_formats(formats): - format = ', '.join(formats).replace( - ISO_8601, - 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' - ) - return humanize_strptime(format) - - -def readable_date_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') - return humanize_strptime(format) - - -def readable_time_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') - return humanize_strptime(format) - - -def humanize_strptime(format_string): - # Note that we're missing some of the locale specific mappings that - # don't really make sense. - mapping = { - "%Y": "YYYY", - "%y": "YY", - "%m": "MM", - "%b": "[Jan-Dec]", - "%B": "[January-December]", - "%d": "DD", - "%H": "hh", - "%I": "hh", # Requires '%p' to differentiate from '%H'. - "%M": "mm", - "%S": "ss", - "%f": "uuuuuu", - "%a": "[Mon-Sun]", - "%A": "[Monday-Sunday]", - "%p": "[AM|PM]", - "%z": "[+HHMM|-HHMM]" - } - for key, val in mapping.items(): - format_string = format_string.replace(key, val) - return format_string + for attr in attrs: + if instance is None: + # Break out early if we get `None` at any point in a nested lookup. + return None + try: + if isinstance(instance, collections.Mapping): + instance = instance[attr] + else: + instance = getattr(instance, attr) + except ObjectDoesNotExist: + return None + if is_simple_callable(instance): + instance = instance() + return instance -def strip_multiple_choice_msg(help_text): +def set_value(dictionary, keys, value): """ - Remove the 'Hold down "control" ...' message that is Django enforces in - select multiple fields on ModelForms. (Required for 1.5 and earlier) + Similar to Python's built in `dictionary[key] = value`, + but takes a list of nested keys instead of a single key. - See https://code.djangoproject.com/ticket/9321 + set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2} + set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2} + set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}} """ - multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') - multiple_choice_msg = force_text(multiple_choice_msg) + if not keys: + dictionary.update(value) + return - return help_text.replace(multiple_choice_msg, '') + for key in keys[:-1]: + if key not in dictionary: + dictionary[key] = {} + dictionary = dictionary[key] + dictionary[keys[-1]] = value -class Field(object): - read_only = True - creation_counter = 0 - empty = '' - type_name = None - partial = False - use_files = False - form_field_class = forms.CharField - type_label = 'field' - widget = None - - def __init__(self, source=None, label=None, help_text=None): - self.parent = None - self.creation_counter = Field.creation_counter - Field.creation_counter += 1 +class CreateOnlyDefault: + """ + This class may be used to provide default values that are only used + for create operations, but that do not return any value for update + operations. + """ + def __init__(self, default): + self.default = default - self.source = source + def set_context(self, serializer_field): + self.is_update = serializer_field.parent.instance is not None - if label is not None: - self.label = smart_text(label) - else: - self.label = None + def __call__(self): + if self.is_update: + raise SkipField() + if callable(self.default): + return self.default() + return self.default - if help_text is not None: - self.help_text = strip_multiple_choice_msg(smart_text(help_text)) - else: - self.help_text = None + def __repr__(self): + return unicode_to_repr( + '%s(%s)' % (self.__class__.__name__, unicode_repr(self.default)) + ) - self._errors = [] - self._value = None - self._name = None - @property - def errors(self): - return self._errors +class CurrentUserDefault: + def set_context(self, serializer_field): + self.user = serializer_field.context['request'].user - def widget_html(self): - if not self.widget: - return '' + def __call__(self): + return self.user - attrs = {} - if 'id' not in self.widget.attrs: - attrs['id'] = self._name + def __repr__(self): + return unicode_to_repr('%s()' % self.__class__.__name__) - return self.widget.render(self._name, self._value, attrs=attrs) - def label_tag(self): - return '<label for="%s">%s:</label>' % (self._name, self.label) +class SkipField(Exception): + pass - def initialize(self, parent, field_name): - """ - Called to set up a field prior to field_to_native or field_from_native. - parent - The parent serializer. - field_name - The name of the field being initialized. - """ - self.parent = parent - self.root = parent.root or parent - self.context = self.root.context - self.partial = self.root.partial - if self.partial: - self.required = False +NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' +NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' +NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' +USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField' +MISSING_ERROR_MESSAGE = ( + 'ValidationError raised by `{class_name}`, but error key `{key}` does ' + 'not exist in the `error_messages` dictionary.' +) - def field_from_native(self, data, files, field_name, into): - """ - Given a dictionary and a field name, updates the dictionary `into`, - with the field and it's deserialized value. - """ - return - def field_to_native(self, obj, field_name): - """ - Given an object and a field name, returns the value that should be - serialized for that field. - """ - if obj is None: - return self.empty +class Field(object): + _creation_counter = 0 - if self.source == '*': - return self.to_native(obj) + default_error_messages = { + 'required': _('This field is required.'), + 'null': _('This field may not be null.') + } + default_validators = [] + default_empty_html = empty + initial = None - source = self.source or field_name - value = obj + def __init__(self, read_only=False, write_only=False, + required=None, default=empty, initial=empty, source=None, + label=None, help_text=None, style=None, + error_messages=None, validators=None, allow_null=False): + self._creation_counter = Field._creation_counter + Field._creation_counter += 1 - for component in source.split('.'): - value = get_component(value, component) - if value is None: - break + # If `required` is unset, then use `True` unless a default is provided. + if required is None: + required = default is empty and not read_only - return self.to_native(value) + # Some combinations of keyword arguments do not make sense. + assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY + assert not (read_only and required), NOT_READ_ONLY_REQUIRED + assert not (required and default is not empty), NOT_REQUIRED_DEFAULT + assert not (read_only and self.__class__ == Field), USE_READONLYFIELD - def to_native(self, value): - """ - Converts the field's value into it's simple representation. - """ - if is_simple_callable(value): - value = value() + self.read_only = read_only + self.write_only = write_only + self.required = required + self.default = default + self.source = source + self.initial = self.initial if (initial is empty) else initial + self.label = label + self.help_text = help_text + self.style = {} if style is None else style + self.allow_null = allow_null + + if self.default_empty_html is not empty: + if not required: + self.default_empty_html = empty + elif default is not empty: + self.default_empty_html = default + + if validators is not None: + self.validators = validators[:] + + # These are set up by `.bind()` when the field is added to a serializer. + self.field_name = None + self.parent = None - if is_protected_type(value): - return value - elif (is_non_str_iterable(value) and - not isinstance(value, (dict, six.string_types))): - return [self.to_native(item) for item in value] - elif isinstance(value, dict): - # Make sure we preserve field ordering, if it exists - ret = SortedDict() - for key, val in value.items(): - ret[key] = self.to_native(val) - return ret - return force_text(value) - - def attributes(self): + # Collect default error message from self and parent classes + messages = {} + for cls in reversed(self.__class__.__mro__): + messages.update(getattr(cls, 'default_error_messages', {})) + messages.update(error_messages or {}) + self.error_messages = messages + + def bind(self, field_name, parent): """ - Returns a dictionary of attributes to be used when serializing to xml. + Initializes the field name and parent for the field instance. + Called when a field is added to the parent serializer instance. """ - if self.type_name: - return {'type': self.type_name} - return {} - - def metadata(self): - metadata = SortedDict() - metadata['type'] = self.type_label - metadata['required'] = getattr(self, 'required', False) - optional_attrs = ['read_only', 'label', 'help_text', - 'min_length', 'max_length'] - for attr in optional_attrs: - value = getattr(self, attr, None) - if value is not None and value != '': - metadata[attr] = force_text(value, strings_only=True) - return metadata - - -class WritableField(Field): - """ - Base for read/write fields. - """ - write_only = False - default_validators = [] - default_error_messages = { - 'required': _('This field is required.'), - 'invalid': _('Invalid value.'), - } - widget = widgets.TextInput - default = None - def __init__(self, source=None, label=None, help_text=None, - read_only=False, write_only=False, required=None, - validators=[], error_messages=None, widget=None, - default=None, blank=None): + # In order to enforce a consistent style, we error if a redundant + # 'source' argument has been used. For example: + # my_field = serializer.CharField(source='my_field') + assert self.source != field_name, ( + "It is redundant to specify `source='%s'` on field '%s' in " + "serializer '%s', because it is the same as the field name. " + "Remove the `source` keyword argument." % + (field_name, self.__class__.__name__, parent.__class__.__name__) + ) - super(WritableField, self).__init__(source=source, label=label, help_text=help_text) + self.field_name = field_name + self.parent = parent - self.read_only = read_only - self.write_only = write_only + # `self.label` should default to being based on the field name. + if self.label is None: + self.label = field_name.replace('_', ' ').capitalize() - assert not (read_only and write_only), "Cannot set read_only=True and write_only=True" + # self.source should default to being the same as the field name. + if self.source is None: + self.source = field_name - if required is None: - self.required = not(read_only) + # self.source_attrs is a list of attributes that need to be looked up + # when serializing the instance, or populating the validated data. + if self.source == '*': + self.source_attrs = [] else: - assert not (read_only and required), "Cannot set required=True and read_only=True" - self.required = required + self.source_attrs = self.source.split('.') - messages = {} - for c in reversed(self.__class__.__mro__): - messages.update(getattr(c, 'default_error_messages', {})) - messages.update(error_messages or {}) - self.error_messages = messages + # .validators is a lazily loaded property, that gets its default + # value from `get_validators`. + @property + def validators(self): + if not hasattr(self, '_validators'): + self._validators = self.get_validators() + return self._validators - self.validators = self.default_validators + validators - self.default = default if default is not None else self.default + @validators.setter + def validators(self, validators): + self._validators = validators - # Widgets are only used for HTML forms. - widget = widget or self.widget - if isinstance(widget, type): - widget = widget() - self.widget = widget + def get_validators(self): + return self.default_validators[:] - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - result.validators = self.validators[:] - return result + def get_initial(self): + """ + Return a value to use when the field is being returned as a primitive + value, without any object instance. + """ + return self.initial - def get_default_value(self): - if is_simple_callable(self.default): + def get_value(self, dictionary): + """ + Given the *incoming* primitive data, return the value for this field + that should be validated and transformed to a native value. + """ + if html.is_html_input(dictionary): + # HTML forms will represent empty fields as '', and cannot + # represent None or False values directly. + if self.field_name not in dictionary: + if getattr(self.root, 'partial', False): + return empty + return self.default_empty_html + ret = dictionary[self.field_name] + return self.default_empty_html if (ret == '') else ret + return dictionary.get(self.field_name, empty) + + def get_attribute(self, instance): + """ + Given the *outgoing* object instance, return the primitive value + that should be used for this field. + """ + try: + return get_attribute(instance, self.source_attrs) + except (KeyError, AttributeError) as exc: + msg = ( + 'Got {exc_type} when attempting to get a value for field ' + '`{field}` on serializer `{serializer}`.\nThe serializer ' + 'field might be named incorrectly and not match ' + 'any attribute or key on the `{instance}` instance.\n' + 'Original exception text was: {exc}.'.format( + exc_type=type(exc).__name__, + field=self.field_name, + serializer=self.parent.__class__.__name__, + instance=instance.__class__.__name__, + exc=exc + ) + ) + raise type(exc)(msg) + + def get_default(self): + """ + Return the default value to use when validating data if no input + is provided for this field. + + If a default has not been set for this field then this will simply + return `empty`, indicating that no value should be set in the + validated data for this field. + """ + if self.default is empty: + raise SkipField() + if callable(self.default): + if hasattr(self.default, 'set_context'): + self.default.set_context(self) return self.default() return self.default - def validate(self, value): - if value in validators.EMPTY_VALUES and self.required: - raise ValidationError(self.error_messages['required']) + def validate_empty_values(self, data): + """ + Validate empty values, and either: + + * Raise `ValidationError`, indicating invalid data. + * Raise `SkipField`, indicating that the field should be ignored. + * Return (True, data), indicating an empty value that should be + returned without any furhter validation being applied. + * Return (False, data), indicating a non-empty value, that should + have validation applied as normal. + """ + if self.read_only: + return (True, self.get_default()) + + if data is empty: + if getattr(self.root, 'partial', False): + raise SkipField() + if self.required: + self.fail('required') + return (True, self.get_default()) + + if data is None: + if not self.allow_null: + self.fail('null') + return (True, None) + + return (False, data) + + def run_validation(self, data=empty): + """ + Validate a simple representation and return the internal value. + + The provided data may be `empty` if no representation was included + in the input. + + May raise `SkipField` if the field should not be included in the + validated data. + """ + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data + value = self.to_internal_value(data) + self.run_validators(value) + return value def run_validators(self, value): - if value in validators.EMPTY_VALUES: - return + """ + Test the given value against all the validators on the field, + and either raise a `ValidationError` or simply return. + """ errors = [] - for v in self.validators: + for validator in self.validators: + if hasattr(validator, 'set_context'): + validator.set_context(self) + try: - v(value) - except ValidationError as e: - if hasattr(e, 'code') and e.code in self.error_messages: - message = self.error_messages[e.code] - if e.params: - message = message % e.params - errors.append(message) - else: - errors.extend(e.messages) + validator(value) + except ValidationError as exc: + # If the validation error contains a mapping of fields to + # errors then simply raise it immediately rather than + # attempting to accumulate a list of errors. + if isinstance(exc.detail, dict): + raise + errors.extend(exc.detail) + except DjangoValidationError as exc: + errors.extend(exc.messages) if errors: raise ValidationError(errors) - def field_to_native(self, obj, field_name): - if self.write_only: - return None - return super(WritableField, self).field_to_native(obj, field_name) + def to_internal_value(self, data): + """ + Transform the *incoming* primitive data into a native value. + """ + raise NotImplementedError( + '{cls}.to_internal_value() must be implemented.'.format( + cls=self.__class__.__name__ + ) + ) - def field_from_native(self, data, files, field_name, into): + def to_representation(self, value): """ - Given a dictionary and a field name, updates the dictionary `into`, - with the field and it's deserialized value. + Transform the *outgoing* native value into primitive data. """ - if self.read_only: - return + raise NotImplementedError( + '{cls}.to_representation() must be implemented.\n' + 'If you are upgrading from REST framework version 2 ' + 'you might want `ReadOnlyField`.'.format( + cls=self.__class__.__name__ + ) + ) + def fail(self, key, **kwargs): + """ + A helper method that simply raises a validation error. + """ try: - data = data or {} - if self.use_files: - files = files or {} - try: - native = files[field_name] - except KeyError: - native = data[field_name] - else: - native = data[field_name] + msg = self.error_messages[key] except KeyError: - if self.default is not None and not self.partial: - # Note: partial updates shouldn't set defaults - native = self.get_default_value() - else: - if self.required: - raise ValidationError(self.error_messages['required']) - return + class_name = self.__class__.__name__ + msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) + raise AssertionError(msg) + message_string = msg.format(**kwargs) + raise ValidationError(message_string) - value = self.from_native(native) - if self.source == '*': - if value: - into.update(value) - else: - self.validate(value) - self.run_validators(value) - into[self.source or field_name] = value + @property + def root(self): + """ + Returns the top-level serializer for this field. + """ + root = self + while root.parent is not None: + root = root.parent + return root - def from_native(self, value): + @property + def context(self): """ - Reverts a simple representation back to the field's value. + Returns the context as passed to the root serializer on initialization. """ - return value + return getattr(self.root, '_context', {}) + def __new__(cls, *args, **kwargs): + """ + When a field is instantiated, we store the arguments that were used, + so that we can present a helpful representation of the object. + """ + instance = super(Field, cls).__new__(cls) + instance._args = args + instance._kwargs = kwargs + return instance -class ModelField(WritableField): - """ - A generic field that can be used against an arbitrary model field. - """ - def __init__(self, *args, **kwargs): - try: - self.model_field = kwargs.pop('model_field') - except KeyError: - raise ValueError("ModelField requires 'model_field' kwarg") - - self.min_length = kwargs.pop('min_length', - getattr(self.model_field, 'min_length', None)) - self.max_length = kwargs.pop('max_length', - getattr(self.model_field, 'max_length', None)) - self.min_value = kwargs.pop('min_value', - getattr(self.model_field, 'min_value', None)) - self.max_value = kwargs.pop('max_value', - getattr(self.model_field, 'max_value', None)) - - super(ModelField, self).__init__(*args, **kwargs) - - if self.min_length is not None: - self.validators.append(validators.MinLengthValidator(self.min_length)) - if self.max_length is not None: - self.validators.append(validators.MaxLengthValidator(self.max_length)) - if self.min_value is not None: - self.validators.append(validators.MinValueValidator(self.min_value)) - if self.max_value is not None: - self.validators.append(validators.MaxValueValidator(self.max_value)) - - def from_native(self, value): - rel = getattr(self.model_field, "rel", None) - if rel is not None: - return rel.to._meta.get_field(rel.field_name).to_python(value) - else: - return self.model_field.to_python(value) + def __deepcopy__(self, memo): + """ + When cloning fields we instantiate using the arguments it was + originally created with, rather than copying the complete state. + """ + args = copy.deepcopy(self._args) + kwargs = dict(self._kwargs) + # Bit ugly, but we need to special case 'validators' as Django's + # RegexValidator does not support deepcopy. + # We treat validator callables as immutable objects. + # See https://github.com/tomchristie/django-rest-framework/issues/1954 + validators = kwargs.pop('validators', None) + kwargs = copy.deepcopy(kwargs) + if validators is not None: + kwargs['validators'] = validators + return self.__class__(*args, **kwargs) + + def __repr__(self): + """ + Fields are represented using their initial calling arguments. + This allows us to create descriptive representations for serializer + instances that show all the declared fields on the serializer. + """ + return unicode_to_repr(representation.field_repr(self)) - def field_to_native(self, obj, field_name): - value = self.model_field._get_val_from_obj(obj) - if is_protected_type(value): - return value - return self.model_field.value_to_string(obj) - def attributes(self): - return { - "type": self.model_field.get_internal_type() - } +# Boolean types... + +class BooleanField(Field): + default_error_messages = { + 'invalid': _('`{input}` is not a valid boolean.') + } + default_empty_html = False + initial = False + TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) + FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) + + def __init__(self, **kwargs): + assert 'allow_null' not in kwargs, '`allow_null` is not a valid option. Use `NullBooleanField` instead.' + super(BooleanField, self).__init__(**kwargs) + + def to_internal_value(self, data): + if data in self.TRUE_VALUES: + return True + elif data in self.FALSE_VALUES: + return False + self.fail('invalid', input=data) + def to_representation(self, value): + if value in self.TRUE_VALUES: + return True + elif value in self.FALSE_VALUES: + return False + return bool(value) -# Typed Fields -class BooleanField(WritableField): - type_name = 'BooleanField' - type_label = 'boolean' - form_field_class = forms.BooleanField - widget = widgets.CheckboxInput +class NullBooleanField(Field): default_error_messages = { - 'invalid': _("'%s' value must be either True or False."), + 'invalid': _('`{input}` is not a valid boolean.') } - empty = False + initial = None + TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) + FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) + NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None)) - def field_from_native(self, data, files, field_name, into): - # HTML checkboxes do not explicitly represent unchecked as `False` - # we deal with that here... - if isinstance(data, QueryDict) and self.default is None: - self.default = False + def __init__(self, **kwargs): + assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.' + kwargs['allow_null'] = True + super(NullBooleanField, self).__init__(**kwargs) - return super(BooleanField, self).field_from_native( - data, files, field_name, into - ) + def to_internal_value(self, data): + if data in self.TRUE_VALUES: + return True + elif data in self.FALSE_VALUES: + return False + elif data in self.NULL_VALUES: + return None + self.fail('invalid', input=data) - def from_native(self, value): - if value in ('true', 't', 'True', '1'): + def to_representation(self, value): + if value in self.NULL_VALUES: + return None + if value in self.TRUE_VALUES: return True - if value in ('false', 'f', 'False', '0'): + elif value in self.FALSE_VALUES: return False return bool(value) -class CharField(WritableField): - type_name = 'CharField' - type_label = 'string' - form_field_class = forms.CharField +# String types... - def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs): - self.max_length, self.min_length = max_length, min_length - self.allow_none = allow_none - super(CharField, self).__init__(*args, **kwargs) - if min_length is not None: - self.validators.append(validators.MinLengthValidator(min_length)) - if max_length is not None: - self.validators.append(validators.MaxLengthValidator(max_length)) +class CharField(Field): + default_error_messages = { + 'blank': _('This field may not be blank.'), + 'max_length': _('Ensure this field has no more than {max_length} characters.'), + 'min_length': _('Ensure this field has at least {min_length} characters.') + } + initial = '' + coerce_blank_to_null = False + default_empty_html = '' - def from_native(self, value): - if isinstance(value, six.string_types): - return value + def __init__(self, **kwargs): + self.allow_blank = kwargs.pop('allow_blank', False) + max_length = kwargs.pop('max_length', None) + min_length = kwargs.pop('min_length', None) + super(CharField, self).__init__(**kwargs) + if max_length is not None: + message = self.error_messages['max_length'].format(max_length=max_length) + self.validators.append(MaxLengthValidator(max_length, message=message)) + if min_length is not None: + message = self.error_messages['min_length'].format(min_length=min_length) + self.validators.append(MinLengthValidator(min_length, message=message)) + + if self.allow_null and (not self.allow_blank) and (self.default is empty): + # HTML input cannot represent `None` values, so we need to + # forcibly coerce empty HTML values to `None` if `allow_null=True`. + self.default_empty_html = None + + def run_validation(self, data=empty): + # Test for the empty string here so that it does not get validated, + # and so that subclasses do not need to handle it explicitly + # inside the `to_internal_value()` method. + if data == '': + if not self.allow_blank: + self.fail('blank') + return '' + return super(CharField, self).run_validation(data) - if value is None: - if not self.allow_none: - return '' - else: - # Return None explicitly because smart_text(None) == 'None'. See #1834 for details - return None + def to_internal_value(self, data): + return six.text_type(data) - return smart_text(value) + def to_representation(self, value): + return six.text_type(value) -class URLField(CharField): - type_name = 'URLField' - type_label = 'url' +class EmailField(CharField): + default_error_messages = { + 'invalid': _('Enter a valid email address.') + } def __init__(self, **kwargs): - if 'validators' not in kwargs: - kwargs['validators'] = [validators.URLValidator()] - super(URLField, self).__init__(**kwargs) + super(EmailField, self).__init__(**kwargs) + validator = EmailValidator(message=self.error_messages['invalid']) + self.validators.append(validator) + def to_internal_value(self, data): + return six.text_type(data).strip() + + def to_representation(self, value): + return six.text_type(value).strip() -class SlugField(CharField): - type_name = 'SlugField' - type_label = 'slug' - form_field_class = forms.SlugField +class RegexField(CharField): default_error_messages = { - 'invalid': _("Enter a valid 'slug' consisting of letters, numbers," - " underscores or hyphens."), + 'invalid': _('This value does not match the required pattern.') } - default_validators = [validators.validate_slug] - def __init__(self, *args, **kwargs): - super(SlugField, self).__init__(*args, **kwargs) + def __init__(self, regex, **kwargs): + super(RegexField, self).__init__(**kwargs) + validator = RegexValidator(regex, message=self.error_messages['invalid']) + self.validators.append(validator) -class ChoiceField(WritableField): - type_name = 'ChoiceField' - type_label = 'choice' - form_field_class = forms.ChoiceField - widget = widgets.Select +class SlugField(CharField): default_error_messages = { - 'invalid_choice': _('Select a valid choice. %(value)s is not one of ' - 'the available choices.'), + 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.") } - def __init__(self, choices=(), blank_display_value=None, *args, **kwargs): - self.empty = kwargs.pop('empty', '') - super(ChoiceField, self).__init__(*args, **kwargs) - self.choices = choices - if not self.required: - if blank_display_value is None: - blank_choice = BLANK_CHOICE_DASH - else: - blank_choice = [('', blank_display_value)] - self.choices = blank_choice + self.choices + def __init__(self, **kwargs): + super(SlugField, self).__init__(**kwargs) + slug_regex = re.compile(r'^[-a-zA-Z0-9_]+$') + validator = RegexValidator(slug_regex, message=self.error_messages['invalid']) + self.validators.append(validator) - def _get_choices(self): - return self._choices - def _set_choices(self, value): - # Setting choices also sets the choices on the widget. - # choices can be any iterable, but we call list() on it because - # it will be consumed more than once. - self._choices = self.widget.choices = list(value) +class URLField(CharField): + default_error_messages = { + 'invalid': _("Enter a valid URL.") + } - choices = property(_get_choices, _set_choices) + def __init__(self, **kwargs): + super(URLField, self).__init__(**kwargs) + validator = URLValidator(message=self.error_messages['invalid']) + self.validators.append(validator) + + +# Number types... + +class IntegerField(Field): + default_error_messages = { + 'invalid': _('A valid integer is required.'), + 'max_value': _('Ensure this value is less than or equal to {max_value}.'), + 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), + 'max_string_length': _('String value too large') + } + MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - def metadata(self): - data = super(ChoiceField, self).metadata() - data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices] + def __init__(self, **kwargs): + max_value = kwargs.pop('max_value', None) + min_value = kwargs.pop('min_value', None) + super(IntegerField, self).__init__(**kwargs) + if max_value is not None: + message = self.error_messages['max_value'].format(max_value=max_value) + self.validators.append(MaxValueValidator(max_value, message=message)) + if min_value is not None: + message = self.error_messages['min_value'].format(min_value=min_value) + self.validators.append(MinValueValidator(min_value, message=message)) + + def to_internal_value(self, data): + if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: + self.fail('max_string_length') + + try: + data = int(data) + except (ValueError, TypeError): + self.fail('invalid') return data - def validate(self, value): - """ - Validates that the input is in self.choices. - """ - super(ChoiceField, self).validate(value) - if value and not self.valid_value(value): - raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) + def to_representation(self, value): + return int(value) - def valid_value(self, value): - """ - Check to see if the provided value is a valid choice. - """ - for k, v in self.choices: - if isinstance(v, (list, tuple)): - # This is an optgroup, so look inside the group for options - for k2, v2 in v: - if value == smart_text(k2) or value == k2: - return True - else: - if value == smart_text(k) or value == k: - return True - return False - def from_native(self, value): - value = super(ChoiceField, self).from_native(value) - if value == self.empty or value in validators.EMPTY_VALUES: - return self.empty - return value +class FloatField(Field): + default_error_messages = { + 'invalid': _("A valid number is required."), + 'max_value': _('Ensure this value is less than or equal to {max_value}.'), + 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), + 'max_string_length': _('String value too large') + } + MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. + def __init__(self, **kwargs): + max_value = kwargs.pop('max_value', None) + min_value = kwargs.pop('min_value', None) + super(FloatField, self).__init__(**kwargs) + if max_value is not None: + message = self.error_messages['max_value'].format(max_value=max_value) + self.validators.append(MaxValueValidator(max_value, message=message)) + if min_value is not None: + message = self.error_messages['min_value'].format(min_value=min_value) + self.validators.append(MinValueValidator(min_value, message=message)) + + def to_internal_value(self, data): + if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: + self.fail('max_string_length') + + try: + return float(data) + except (TypeError, ValueError): + self.fail('invalid') + + def to_representation(self, value): + return float(value) -class EmailField(CharField): - type_name = 'EmailField' - type_label = 'email' - form_field_class = forms.EmailField +class DecimalField(Field): default_error_messages = { - 'invalid': _('Enter a valid email address.'), + 'invalid': _('A valid number is required.'), + 'max_value': _('Ensure this value is less than or equal to {max_value}.'), + 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), + 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), + 'max_string_length': _('String value too large') } - default_validators = [validators.validate_email] + MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - def from_native(self, value): - ret = super(EmailField, self).from_native(value) - if ret is None: - return None - return ret.strip() + coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING + def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, **kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string + super(DecimalField, self).__init__(**kwargs) + if max_value is not None: + message = self.error_messages['max_value'].format(max_value=max_value) + self.validators.append(MaxValueValidator(max_value, message=message)) + if min_value is not None: + message = self.error_messages['min_value'].format(min_value=min_value) + self.validators.append(MinValueValidator(min_value, message=message)) -class RegexField(CharField): - type_name = 'RegexField' - type_label = 'regex' - form_field_class = forms.RegexField + def to_internal_value(self, data): + """ + Validates that the input is a decimal number. Returns a Decimal + instance. Returns None for empty values. Ensures that there are no more + than max_digits in the number, and no more than decimal_places digits + after the decimal point. + """ + data = smart_text(data).strip() + if len(data) > self.MAX_STRING_LENGTH: + self.fail('max_string_length') - def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): - super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) - self.regex = regex + try: + value = decimal.Decimal(data) + except decimal.DecimalException: + self.fail('invalid') - def _get_regex(self): - return self._regex + # Check for NaN. It is the only value that isn't equal to itself, + # so we can use this to identify NaN values. + if value != value: + self.fail('invalid') - def _set_regex(self, regex): - if isinstance(regex, six.string_types): - regex = re.compile(regex) - self._regex = regex - if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: - self.validators.remove(self._regex_validator) - self._regex_validator = validators.RegexValidator(regex=regex) - self.validators.append(self._regex_validator) + # Check for infinity and negative infinity. + if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): + self.fail('invalid') - regex = property(_get_regex, _set_regex) + sign, digittuple, exponent = value.as_tuple() + decimals = abs(exponent) + # digittuple doesn't include any leading zeros. + digits = len(digittuple) + if decimals > digits: + # We have leading zeros up to or past the decimal point. Count + # everything past the decimal point as a digit. We do not count + # 0 before the decimal point as a digit since that would mean + # we would not allow max_digits = decimal_places. + digits = decimals + whole_digits = digits - decimals + + if self.max_digits is not None and digits > self.max_digits: + self.fail('max_digits', max_digits=self.max_digits) + if self.decimal_places is not None and decimals > self.decimal_places: + self.fail('max_decimal_places', max_decimal_places=self.decimal_places) + if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): + self.fail('max_whole_digits', max_whole_digits=self.max_digits - self.decimal_places) + + return value + + def to_representation(self, value): + if not isinstance(value, decimal.Decimal): + value = decimal.Decimal(six.text_type(value).strip()) + + context = decimal.getcontext().copy() + context.prec = self.max_digits + quantized = value.quantize( + decimal.Decimal('.1') ** self.decimal_places, + context=context + ) + if not self.coerce_to_string: + return quantized + return '{0:f}'.format(quantized) -class DateField(WritableField): - type_name = 'DateField' - type_label = 'date' - widget = widgets.DateInput - form_field_class = forms.DateField +# Date & time fields... +class DateTimeField(Field): default_error_messages = { - 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), + 'date': _('Expected a datetime but got a date.'), } - empty = None - input_formats = api_settings.DATE_INPUT_FORMATS - format = api_settings.DATE_FORMAT + format = api_settings.DATETIME_FORMAT + input_formats = api_settings.DATETIME_INPUT_FORMATS + default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None - def __init__(self, input_formats=None, format=None, *args, **kwargs): + def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): + self.format = format if format is not empty else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format - super(DateField, self).__init__(*args, **kwargs) + self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone + super(DateTimeField, self).__init__(*args, **kwargs) - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None + def enforce_timezone(self, value): + """ + When `self.default_timezone` is `None`, always return naive datetimes. + When `self.default_timezone` is not `None`, always return aware datetimes. + """ + if (self.default_timezone is not None) and not timezone.is_aware(value): + return timezone.make_aware(value, self.default_timezone) + elif (self.default_timezone is None) and timezone.is_aware(value): + return timezone.make_naive(value, timezone.UTC()) + return value + + def to_internal_value(self, value): + if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): + self.fail('date') if isinstance(value, datetime.datetime): - if timezone and settings.USE_TZ and timezone.is_aware(value): - # Convert aware datetimes to the default time zone - # before casting them to dates (#17742). - default_timezone = timezone.get_default_timezone() - value = timezone.make_naive(value, default_timezone) - return value.date() - if isinstance(value, datetime.date): - return value + return self.enforce_timezone(value) for format in self.input_formats: if format.lower() == ISO_8601: try: - parsed = parse_date(value) + parsed = parse_datetime(value) except (ValueError, TypeError): pass else: if parsed is not None: - return parsed + return self.enforce_timezone(parsed) else: try: parsed = datetime.datetime.strptime(value, format) except (ValueError, TypeError): pass else: - return parsed.date() + return self.enforce_timezone(parsed) - msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) - raise ValidationError(msg) + humanized_format = humanize_datetime.datetime_formats(self.input_formats) + self.fail('invalid', format=humanized_format) - def to_native(self, value): - if value is None or self.format is None: + def to_representation(self, value): + if self.format is None: return value - if isinstance(value, datetime.datetime): - value = value.date() - if self.format.lower() == ISO_8601: - return value.isoformat() + value = value.isoformat() + if value.endswith('+00:00'): + value = value[:-6] + 'Z' + return value return value.strftime(self.format) -class DateTimeField(WritableField): - type_name = 'DateTimeField' - type_label = 'datetime' - widget = widgets.DateTimeInput - form_field_class = forms.DateTimeField - +class DateField(Field): default_error_messages = { - 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), + 'datetime': _('Expected a date but got a datetime.'), } - empty = None - input_formats = api_settings.DATETIME_INPUT_FORMATS - format = api_settings.DATETIME_FORMAT + format = api_settings.DATE_FORMAT + input_formats = api_settings.DATE_INPUT_FORMATS - def __init__(self, input_formats=None, format=None, *args, **kwargs): + def __init__(self, format=empty, input_formats=None, *args, **kwargs): + self.format = format if format is not empty else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format - super(DateTimeField, self).__init__(*args, **kwargs) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None + super(DateField, self).__init__(*args, **kwargs) + def to_internal_value(self, value): if isinstance(value, datetime.datetime): - return value + self.fail('datetime') + if isinstance(value, datetime.date): - value = datetime.datetime(value.year, value.month, value.day) - if settings.USE_TZ: - # For backwards compatibility, interpret naive datetimes in - # local time. This won't work during DST change, but we can't - # do much about it, so we let the exceptions percolate up the - # call stack. - warnings.warn("DateTimeField received a naive datetime (%s)" - " while time zone support is active." % value, - RuntimeWarning) - default_timezone = timezone.get_default_timezone() - value = timezone.make_aware(value, default_timezone) return value for format in self.input_formats: if format.lower() == ISO_8601: try: - parsed = parse_datetime(value) + parsed = parse_date(value) except (ValueError, TypeError): pass else: @@ -735,45 +891,42 @@ class DateTimeField(WritableField): except (ValueError, TypeError): pass else: - return parsed + return parsed.date() - msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) - raise ValidationError(msg) + humanized_format = humanize_datetime.date_formats(self.input_formats) + self.fail('invalid', format=humanized_format) - def to_native(self, value): - if value is None or self.format is None: + def to_representation(self, value): + if self.format is None: return value + # Applying a `DateField` to a datetime value is almost always + # not a sensible thing to do, as it means naively dropping + # any explicit or implicit timezone info. + assert not isinstance(value, datetime.datetime), ( + 'Expected a `date`, but got a `datetime`. Refusing to coerce, ' + 'as this may mean losing timezone information. Use a custom ' + 'read-only field and deal with timezone issues explicitly.' + ) + if self.format.lower() == ISO_8601: - ret = value.isoformat() - if ret.endswith('+00:00'): - ret = ret[:-6] + 'Z' - return ret + return value.isoformat() return value.strftime(self.format) -class TimeField(WritableField): - type_name = 'TimeField' - type_label = 'time' - widget = widgets.TimeInput - form_field_class = forms.TimeField - +class TimeField(Field): default_error_messages = { - 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), + 'invalid': _('Time has wrong format. Use one of these formats instead: {format}'), } - empty = None - input_formats = api_settings.TIME_INPUT_FORMATS format = api_settings.TIME_FORMAT + input_formats = api_settings.TIME_INPUT_FORMATS - def __init__(self, input_formats=None, format=None, *args, **kwargs): + def __init__(self, format=empty, input_formats=None, *args, **kwargs): + self.format = format if format is not empty else self.format self.input_formats = input_formats if input_formats is not None else self.input_formats - self.format = format if format is not None else self.format super(TimeField, self).__init__(*args, **kwargs) - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - + def to_internal_value(self, value): if isinstance(value, datetime.time): return value @@ -794,249 +947,335 @@ class TimeField(WritableField): else: return parsed.time() - msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) - raise ValidationError(msg) + humanized_format = humanize_datetime.time_formats(self.input_formats) + self.fail('invalid', format=humanized_format) - def to_native(self, value): - if value is None or self.format is None: + def to_representation(self, value): + if self.format is None: return value - if isinstance(value, datetime.datetime): - value = value.time() + # Applying a `TimeField` to a datetime value is almost always + # not a sensible thing to do, as it means naively dropping + # any explicit or implicit timezone info. + assert not isinstance(value, datetime.datetime), ( + 'Expected a `time`, but got a `datetime`. Refusing to coerce, ' + 'as this may mean losing timezone information. Use a custom ' + 'read-only field and deal with timezone issues explicitly.' + ) if self.format.lower() == ISO_8601: return value.isoformat() return value.strftime(self.format) -class IntegerField(WritableField): - type_name = 'IntegerField' - type_label = 'integer' - form_field_class = forms.IntegerField - empty = 0 +# Choice types... +class ChoiceField(Field): default_error_messages = { - 'invalid': _('Enter a whole number.'), - 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), - 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), + 'invalid_choice': _('`{input}` is not a valid choice.') } - def __init__(self, max_value=None, min_value=None, *args, **kwargs): - self.max_value, self.min_value = max_value, min_value - super(IntegerField, self).__init__(*args, **kwargs) - - if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) - if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) - - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None - - try: - value = int(str(value)) - except (ValueError, TypeError): - raise ValidationError(self.error_messages['invalid']) - return value + def __init__(self, choices, **kwargs): + # Allow either single or paired choices style: + # choices = [1, 2, 3] + # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] + pairs = [ + isinstance(item, (list, tuple)) and len(item) == 2 + for item in choices + ] + if all(pairs): + self.choices = OrderedDict([(key, display_value) for key, display_value in choices]) + else: + self.choices = OrderedDict([(item, item) for item in choices]) + # Map the string representation of choices to the underlying value. + # Allows us to deal with eg. integer choices while supporting either + # integer or string input, but still get the correct datatype out. + self.choice_strings_to_values = dict([ + (six.text_type(key), key) for key in self.choices.keys() + ]) -class FloatField(WritableField): - type_name = 'FloatField' - type_label = 'float' - form_field_class = forms.FloatField - empty = 0 + self.allow_blank = kwargs.pop('allow_blank', False) - default_error_messages = { - 'invalid': _("'%s' value must be a float."), - } + super(ChoiceField, self).__init__(**kwargs) - def from_native(self, value): - if value in validators.EMPTY_VALUES: - return None + def to_internal_value(self, data): + if data == '' and self.allow_blank: + return '' try: - return float(value) - except (TypeError, ValueError): - msg = self.error_messages['invalid'] % value - raise ValidationError(msg) + return self.choice_strings_to_values[six.text_type(data)] + except KeyError: + self.fail('invalid_choice', input=data) + def to_representation(self, value): + if value in ('', None): + return value + return self.choice_strings_to_values[six.text_type(value)] -class DecimalField(WritableField): - type_name = 'DecimalField' - type_label = 'decimal' - form_field_class = forms.DecimalField - empty = Decimal('0') +class MultipleChoiceField(ChoiceField): default_error_messages = { - 'invalid': _('Enter a number.'), - 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), - 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), - 'max_digits': _('Ensure that there are no more than %s digits in total.'), - 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), - 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') + 'invalid_choice': _('`{input}` is not a valid choice.'), + 'not_a_list': _('Expected a list of items but got type `{input_type}`.') } + default_empty_html = [] - def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): - self.max_value, self.min_value = max_value, min_value - self.max_digits, self.decimal_places = max_digits, decimal_places - super(DecimalField, self).__init__(*args, **kwargs) + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return dictionary.getlist(self.field_name) + return dictionary.get(self.field_name, empty) - if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) - if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) - - def from_native(self, value): - """ - Validates that the input is a decimal number. Returns a Decimal - instance. Returns None for empty values. Ensures that there are no more - than max_digits in the number, and no more than decimal_places digits - after the decimal point. - """ - if value in validators.EMPTY_VALUES: - return None - value = smart_text(value).strip() - try: - value = Decimal(value) - except DecimalException: - raise ValidationError(self.error_messages['invalid']) - return value + def to_internal_value(self, data): + if isinstance(data, type('')) or not hasattr(data, '__iter__'): + self.fail('not_a_list', input_type=type(data).__name__) - def validate(self, value): - super(DecimalField, self).validate(value) - if value in validators.EMPTY_VALUES: - return - # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, - # since it is never equal to itself. However, NaN is the only value that - # isn't equal to itself, so we can use this to identify NaN - if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): - raise ValidationError(self.error_messages['invalid']) - sign, digittuple, exponent = value.as_tuple() - decimals = abs(exponent) - # digittuple doesn't include any leading zeros. - digits = len(digittuple) - if decimals > digits: - # We have leading zeros up to or past the decimal point. Count - # everything past the decimal point as a digit. We do not count - # 0 before the decimal point as a digit since that would mean - # we would not allow max_digits = decimal_places. - digits = decimals - whole_digits = digits - decimals + return set([ + super(MultipleChoiceField, self).to_internal_value(item) + for item in data + ]) - if self.max_digits is not None and digits > self.max_digits: - raise ValidationError(self.error_messages['max_digits'] % self.max_digits) - if self.decimal_places is not None and decimals > self.decimal_places: - raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) - if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): - raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) - return value + def to_representation(self, value): + return set([ + self.choice_strings_to_values[six.text_type(item)] for item in value + ]) -class FileField(WritableField): - use_files = True - type_name = 'FileField' - type_label = 'file upload' - form_field_class = forms.FileField - widget = widgets.FileInput +# File types... +class FileField(Field): default_error_messages = { - 'invalid': _("No file was submitted. Check the encoding type on the form."), - 'missing': _("No file was submitted."), + 'required': _("No file was submitted."), + 'invalid': _("The submitted data was not a file. Check the encoding type on the form."), + 'no_name': _("No filename could be determined."), 'empty': _("The submitted file is empty."), - 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'), - 'contradiction': _('Please either submit a file or check the clear checkbox, not both.') + 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'), } + use_url = api_settings.UPLOADED_FILES_USE_URL def __init__(self, *args, **kwargs): self.max_length = kwargs.pop('max_length', None) self.allow_empty_file = kwargs.pop('allow_empty_file', False) + self.use_url = kwargs.pop('use_url', self.use_url) super(FileField, self).__init__(*args, **kwargs) - def from_native(self, data): - if data in validators.EMPTY_VALUES: - return None - - # UploadedFile objects should have name and size attributes. + def to_internal_value(self, data): try: + # `UploadedFile` objects should have name and size attributes. file_name = data.name file_size = data.size except AttributeError: - raise ValidationError(self.error_messages['invalid']) + self.fail('invalid') - if self.max_length is not None and len(file_name) > self.max_length: - error_values = {'max': self.max_length, 'length': len(file_name)} - raise ValidationError(self.error_messages['max_length'] % error_values) if not file_name: - raise ValidationError(self.error_messages['invalid']) + self.fail('no_name') if not self.allow_empty_file and not file_size: - raise ValidationError(self.error_messages['empty']) + self.fail('empty') + if self.max_length and len(file_name) > self.max_length: + self.fail('max_length', max_length=self.max_length, length=len(file_name)) return data - def to_native(self, value): + def to_representation(self, value): + if self.use_url: + if not value: + return None + url = value.url + request = self.context.get('request', None) + if request is not None: + return request.build_absolute_uri(url) + return url return value.name class ImageField(FileField): - use_files = True - type_name = 'ImageField' - type_label = 'image upload' - form_field_class = forms.ImageField + default_error_messages = { + 'invalid_image': _( + 'Upload a valid image. The file you uploaded was either not an ' + 'image or a corrupted image.' + ), + } + def __init__(self, *args, **kwargs): + self._DjangoImageField = kwargs.pop('_DjangoImageField', DjangoImageField) + super(ImageField, self).__init__(*args, **kwargs) + + def to_internal_value(self, data): + # Image validation is a bit grungy, so we'll just outright + # defer to Django's implementation so we don't need to + # consider it, or treat PIL as a test dependency. + file_object = super(ImageField, self).to_internal_value(data) + django_field = self._DjangoImageField() + django_field.error_messages = self.error_messages + django_field.to_python(file_object) + return file_object + + +# Composite field types... + +class ListField(Field): + child = None + initial = [] default_error_messages = { - 'invalid_image': _("Upload a valid image. The file you uploaded was " - "either not an image or a corrupted image."), + 'not_a_list': _('Expected a list of items but got type `{input_type}`') } - def from_native(self, data): + def __init__(self, *args, **kwargs): + self.child = kwargs.pop('child', copy.deepcopy(self.child)) + assert self.child is not None, '`child` is a required argument.' + assert not inspect.isclass(self.child), '`child` has not been instantiated.' + super(ListField, self).__init__(*args, **kwargs) + self.child.bind(field_name='', parent=self) + + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_list(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) + + def to_internal_value(self, data): """ - Checks that the file-upload field data contains a valid image (GIF, JPG, - PNG, possibly others -- whatever the Python Imaging Library supports). + List of dicts of native values <- List of dicts of primitive datatypes. """ - f = super(ImageField, self).from_native(data) - if f is None: - return None + if html.is_html_input(data): + data = html.parse_html_list(data) + if isinstance(data, type('')) or not hasattr(data, '__iter__'): + self.fail('not_a_list', input_type=type(data).__name__) + return [self.child.run_validation(item) for item in data] - from rest_framework.compat import Image - assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.' + def to_representation(self, data): + """ + List of object instances -> List of dicts of primitive datatypes. + """ + return [self.child.to_representation(item) for item in data] - # We need to get a file object for PIL. We might have a path or we might - # have to read the data into memory. - if hasattr(data, 'temporary_file_path'): - file = data.temporary_file_path() - else: - if hasattr(data, 'read'): - file = BytesIO(data.read()) - else: - file = BytesIO(data['content']) - try: - # load() could spot a truncated JPEG, but it loads the entire - # image in memory, which is a DoS vector. See #3848 and #18520. - # verify() must be called immediately after the constructor. - Image.open(file).verify() - except ImportError: - # Under PyPy, it is possible to import PIL. However, the underlying - # _imaging C module isn't available, so an ImportError will be - # raised. Catch and re-raise. - raise - except Exception: # Python Imaging Library doesn't recognize it as an image - raise ValidationError(self.error_messages['invalid_image']) - if hasattr(f, 'seek') and callable(f.seek): - f.seek(0) - return f +# Miscellaneous field types... +class ReadOnlyField(Field): + """ + A read-only field that simply returns the field value. -class SerializerMethodField(Field): + If the field is a method with no parameters, the method will be called + and it's return value used as the representation. + + For example, the following would call `get_expiry_date()` on the object: + + class ExampleSerializer(self): + expiry_date = ReadOnlyField(source='get_expiry_date') """ - A field that gets its value by calling a method on the serializer it's attached to. + + def __init__(self, **kwargs): + kwargs['read_only'] = True + super(ReadOnlyField, self).__init__(**kwargs) + + def to_representation(self, value): + return value + + +class HiddenField(Field): + """ + A hidden field does not take input from the user, or present any output, + but it does populate a field in `validated_data`, based on its default + value. This is particularly useful when we have a `unique_for_date` + constraint on a pair of fields, as we need some way to include the date in + the validated data. + """ + def __init__(self, **kwargs): + assert 'default' in kwargs, 'default is a required argument.' + kwargs['write_only'] = True + super(HiddenField, self).__init__(**kwargs) + + def get_value(self, dictionary): + # We always use the default value for `HiddenField`. + # User input is never provided or accepted. + return empty + + def to_internal_value(self, data): + return data + + +class SerializerMethodField(Field): """ + A read-only field that get its representation from calling a method on the + parent serializer class. The method called will be of the form + "get_{field_name}", and should take a single argument, which is the + object being serialized. - def __init__(self, method_name, *args, **kwargs): + For example: + + class ExampleSerializer(self): + extra_info = SerializerMethodField() + + def get_extra_info(self, obj): + return ... # Calculate some data to return. + """ + def __init__(self, method_name=None, **kwargs): self.method_name = method_name - super(SerializerMethodField, self).__init__(*args, **kwargs) + kwargs['source'] = '*' + kwargs['read_only'] = True + super(SerializerMethodField, self).__init__(**kwargs) + + def bind(self, field_name, parent): + # In order to enforce a consistent style, we error if a redundant + # 'method_name' argument has been used. For example: + # my_field = serializer.CharField(source='my_field') + default_method_name = 'get_{field_name}'.format(field_name=field_name) + assert self.method_name != default_method_name, ( + "It is redundant to specify `%s` on SerializerMethodField '%s' in " + "serializer '%s', because it is the same as the default method name. " + "Remove the `method_name` argument." % + (self.method_name, field_name, parent.__class__.__name__) + ) + + # The method name should default to `get_{field_name}`. + if self.method_name is None: + self.method_name = default_method_name + + super(SerializerMethodField, self).bind(field_name, parent) + + def to_representation(self, value): + method = getattr(self.parent, self.method_name) + return method(value) + + +class ModelField(Field): + """ + A generic field that can be used against an arbitrary model field. + + This is used by `ModelSerializer` when dealing with custom model fields, + that do not have a serializer field to be mapped to. + """ + default_error_messages = { + 'max_length': _('Ensure this field has no more than {max_length} characters.'), + } + + def __init__(self, model_field, **kwargs): + self.model_field = model_field + # The `max_length` option is supported by Django's base `Field` class, + # so we'd better support it here. + max_length = kwargs.pop('max_length', None) + super(ModelField, self).__init__(**kwargs) + if max_length is not None: + message = self.error_messages['max_length'].format(max_length=max_length) + self.validators.append(MaxLengthValidator(max_length, message=message)) + + def to_internal_value(self, data): + rel = getattr(self.model_field, 'rel', None) + if rel is not None: + return rel.to._meta.get_field(rel.field_name).to_python(data) + return self.model_field.to_python(data) + + def get_attribute(self, obj): + # We pass the object instance onto `to_representation`, + # not just the field attribute. + return obj - def field_to_native(self, obj, field_name): - value = getattr(self.parent, self.method_name)(obj) - return self.to_native(value) + def to_representation(self, obj): + value = self.model_field._get_val_from_obj(obj) + if is_protected_type(value): + return value + return self.model_field.value_to_string(obj) |
