diff options
Diffstat (limited to 'rest_framework/fields.py')
| -rw-r--r-- | rest_framework/fields.py | 794 | 
1 files changed, 594 insertions, 200 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 0c78b3fb..ca9c479f 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,18 +1,25 @@  from django.conf import settings -from django.core import validators -from django.core.exceptions import ValidationError -from django.utils import timezone +from django.core.exceptions import ObjectDoesNotExist +from django.core.exceptions import ValidationError as DjangoValidationError +from django.core.validators import RegexValidator +from django.forms import ImageField as DjangoImageField +from django.utils import six, timezone  from django.utils.dateparse import parse_date, parse_datetime, parse_time  from django.utils.encoding import is_protected_type  from django.utils.translation import ugettext_lazy as _  from rest_framework import ISO_8601 -from rest_framework.compat import smart_text +from rest_framework.compat import ( +    smart_text, EmailValidator, MinValueValidator, MaxValueValidator, +    MinLengthValidator, MaxLengthValidator, URLValidator, OrderedDict +) +from rest_framework.exceptions import ValidationError  from rest_framework.settings import api_settings  from rest_framework.utils import html, representation, humanize_datetime +import copy  import datetime  import decimal  import inspect -import warnings +import re  class empty: @@ -49,13 +56,20 @@ def get_attribute(instance, attrs):      Also accepts either attribute lookup on objects or dictionary lookups.      """      for attr in attrs: +        if instance is None: +            # Break out early if we get `None` at any point in a nested lookup. +            return None          try:              instance = getattr(instance, attr) +        except ObjectDoesNotExist: +            return None          except AttributeError as exc:              try:                  return instance[attr] -            except (KeyError, TypeError): +            except (KeyError, TypeError, AttributeError):                  raise exc +        if is_simple_callable(instance): +            instance = instance()      return instance @@ -80,14 +94,48 @@ def set_value(dictionary, keys, value):      dictionary[keys[-1]] = value +class CreateOnlyDefault: +    """ +    This class may be used to provide default values that are only used +    for create operations, but that do not return any value for update +    operations. +    """ +    def __init__(self, default): +        self.default = default + +    def set_context(self, serializer_field): +        self.is_update = serializer_field.parent.instance is not None + +    def __call__(self): +        if self.is_update: +            raise SkipField() +        if callable(self.default): +            return self.default() +        return self.default + +    def __repr__(self): +        return '%s(%s)' % (self.__class__.__name__, repr(self.default)) + + +class CurrentUserDefault: +    def set_context(self, serializer_field): +        self.user = serializer_field.context['request'].user + +    def __call__(self): +        return self.user + +    def __repr__(self): +        return '%s()' % self.__class__.__name__ + +  class SkipField(Exception):      pass  NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'  NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' -NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'  NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' +USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'  MISSING_ERROR_MESSAGE = (      'ValidationError raised by `{class_name}`, but error key `{key}` does '      'not exist in the `error_messages` dictionary.' @@ -98,14 +146,17 @@ class Field(object):      _creation_counter = 0      default_error_messages = { -        'required': _('This field is required.') +        'required': _('This field is required.'), +        'null': _('This field may not be null.')      }      default_validators = [] +    default_empty_html = empty +    initial = None      def __init__(self, read_only=False, write_only=False, -                 required=None, default=empty, initial=None, source=None, +                 required=None, default=empty, initial=empty, source=None,                   label=None, help_text=None, style=None, -                 error_messages=None, validators=[]): +                 error_messages=None, validators=None, allow_null=False):          self._creation_counter = Field._creation_counter          Field._creation_counter += 1 @@ -116,19 +167,29 @@ class Field(object):          # Some combinations of keyword arguments do not make sense.          assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY          assert not (read_only and required), NOT_READ_ONLY_REQUIRED -        assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT          assert not (required and default is not empty), NOT_REQUIRED_DEFAULT +        assert not (read_only and self.__class__ == Field), USE_READONLYFIELD          self.read_only = read_only          self.write_only = write_only          self.required = required          self.default = default          self.source = source -        self.initial = initial +        self.initial = self.initial if (initial is empty) else initial          self.label = label          self.help_text = help_text          self.style = {} if style is None else style -        self.validators = validators or self.default_validators[:] +        self.allow_null = allow_null + +        if allow_null and self.default_empty_html is empty: +            self.default_empty_html = None + +        if validators is not None: +            self.validators = validators[:] + +        # These are set up by `.bind()` when the field is added to a serializer. +        self.field_name = None +        self.parent = None          # Collect default error message from self and parent classes          messages = {} @@ -137,26 +198,26 @@ class Field(object):          messages.update(error_messages or {})          self.error_messages = messages -    def __new__(cls, *args, **kwargs): +    def bind(self, field_name, parent):          """ -        When a field is instantiated, we store the arguments that were used, -        so that we can present a helpful representation of the object. +        Initializes the field name and parent for the field instance. +        Called when a field is added to the parent serializer instance.          """ -        instance = super(Field, cls).__new__(cls) -        instance._args = args -        instance._kwargs = kwargs -        return instance -    def bind(self, field_name, parent, root): -        """ -        Setup the context for the field instance. -        """ +        # In order to enforce a consistent style, we error if a redundant +        # 'source' argument has been used. For example: +        # my_field = serializer.CharField(source='my_field') +        assert self.source != field_name, ( +            "It is redundant to specify `source='%s'` on field '%s' in " +            "serializer '%s', because it is the same as the field name. " +            "Remove the `source` keyword argument." % +            (field_name, self.__class__.__name__, parent.__class__.__name__) +        ) +          self.field_name = field_name          self.parent = parent -        self.root = root -        self.context = parent.context -        # `self.label` should deafult to being based on the field name. +        # `self.label` should default to being based on the field name.          if self.label is None:              self.label = field_name.replace('_', ' ').capitalize() @@ -171,24 +232,48 @@ class Field(object):          else:              self.source_attrs = self.source.split('.') +    # .validators is a lazily loaded property, that gets its default +    # value from `get_validators`. +    @property +    def validators(self): +        if not hasattr(self, '_validators'): +            self._validators = self.get_validators() +        return self._validators + +    @validators.setter +    def validators(self, validators): +        self._validators = validators + +    def get_validators(self): +        return self.default_validators[:] +      def get_initial(self):          """ -        Return a value to use when the field is being returned as a primative +        Return a value to use when the field is being returned as a primitive          value, without any object instance.          """          return self.initial      def get_value(self, dictionary):          """ -        Given the *incoming* primative data, return the value for this field +        Given the *incoming* primitive data, return the value for this field          that should be validated and transformed to a native value.          """ +        if html.is_html_input(dictionary): +            # HTML forms will represent empty fields as '', and cannot +            # represent None or False values directly. +            if self.field_name not in dictionary: +                if getattr(self.root, 'partial', False): +                    return empty +                return self.default_empty_html +            ret = dictionary[self.field_name] +            return self.default_empty_html if (ret == '') else ret          return dictionary.get(self.field_name, empty)      def get_attribute(self, instance):          """ -        Given the *outgoing* object instance, return the value for this field -        that should be returned as a primative value. +        Given the *outgoing* object instance, return the primitive value +        that should be used for this field.          """          return get_attribute(instance, self.source_attrs) @@ -203,47 +288,74 @@ class Field(object):          """          if self.default is empty:              raise SkipField() +        if callable(self.default): +            if hasattr(self.default, 'set_context'): +                self.default.set_context(self) +            return self.default()          return self.default      def run_validation(self, data=empty):          """          Validate a simple representation and return the internal value. -        The provided data may be `empty` if no representation was included. -        May return `empty` if the field should not be included in the +        The provided data may be `empty` if no representation was included +        in the input. + +        May raise `SkipField` if the field should not be included in the          validated data.          """ +        if self.read_only: +            return self.get_default() +          if data is empty: +            if getattr(self.root, 'partial', False): +                raise SkipField()              if self.required:                  self.fail('required')              return self.get_default() +        if data is None: +            if not self.allow_null: +                self.fail('null') +            return None +          value = self.to_internal_value(data)          self.run_validators(value)          return value      def run_validators(self, value): -        if value in (None, '', [], (), {}): -            return - +        """ +        Test the given value against all the validators on the field, +        and either raise a `ValidationError` or simply return. +        """          errors = []          for validator in self.validators: +            if hasattr(validator, 'set_context'): +                validator.set_context(self) +              try:                  validator(value)              except ValidationError as exc: +                # If the validation error contains a mapping of fields to +                # errors then simply raise it immediately rather than +                # attempting to accumulate a list of errors. +                if isinstance(exc.detail, dict): +                    raise +                errors.extend(exc.detail) +            except DjangoValidationError as exc:                  errors.extend(exc.messages)          if errors:              raise ValidationError(errors)      def to_internal_value(self, data):          """ -        Transform the *incoming* primative data into a native value. +        Transform the *incoming* primitive data into a native value.          """          raise NotImplementedError('to_internal_value() must be implemented.')      def to_representation(self, value):          """ -        Transform the *outgoing* native value into primative data. +        Transform the *outgoing* native value into primitive data.          """          raise NotImplementedError('to_representation() must be implemented.') @@ -257,9 +369,59 @@ class Field(object):              class_name = self.__class__.__name__              msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)              raise AssertionError(msg) -        raise ValidationError(msg.format(**kwargs)) +        message_string = msg.format(**kwargs) +        raise ValidationError(message_string) + +    @property +    def root(self): +        """ +        Returns the top-level serializer for this field. +        """ +        root = self +        while root.parent is not None: +            root = root.parent +        return root + +    @property +    def context(self): +        """ +        Returns the context as passed to the root serializer on initialization. +        """ +        return getattr(self.root, '_context', {}) + +    def __new__(cls, *args, **kwargs): +        """ +        When a field is instantiated, we store the arguments that were used, +        so that we can present a helpful representation of the object. +        """ +        instance = super(Field, cls).__new__(cls) +        instance._args = args +        instance._kwargs = kwargs +        return instance + +    def __deepcopy__(self, memo): +        """ +        When cloning fields we instantiate using the arguments it was +        originally created with, rather than copying the complete state. +        """ +        args = copy.deepcopy(self._args) +        kwargs = dict(self._kwargs) +        # Bit ugly, but we need to special case 'validators' as Django's +        # RegexValidator does not support deepcopy. +        # We treat validator callables as immutable objects. +        # See https://github.com/tomchristie/django-rest-framework/issues/1954 +        validators = kwargs.pop('validators', None) +        kwargs = copy.deepcopy(kwargs) +        if validators is not None: +            kwargs['validators'] = validators +        return self.__class__(*args, **kwargs)      def __repr__(self): +        """ +        Fields are represented using their initial calling arguments. +        This allows us to create descriptive representations for serializer +        instances that show all the declared fields on the serializer. +        """          return representation.field_repr(self) @@ -269,25 +431,55 @@ class BooleanField(Field):      default_error_messages = {          'invalid': _('`{input}` is not a valid boolean.')      } +    default_empty_html = False +    initial = False      TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))      FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) -    def get_value(self, dictionary): -        if html.is_html_input(dictionary): -            # HTML forms do not send a `False` value on an empty checkbox, -            # so we override the default empty value to be False. -            return dictionary.get(self.field_name, False) -        return dictionary.get(self.field_name, empty) +    def __init__(self, **kwargs): +        assert 'allow_null' not in kwargs, '`allow_null` is not a valid option. Use `NullBooleanField` instead.' +        super(BooleanField, self).__init__(**kwargs) + +    def to_internal_value(self, data): +        if data in self.TRUE_VALUES: +            return True +        elif data in self.FALSE_VALUES: +            return False +        self.fail('invalid', input=data) + +    def to_representation(self, value): +        if value in self.TRUE_VALUES: +            return True +        elif value in self.FALSE_VALUES: +            return False +        return bool(value) + + +class NullBooleanField(Field): +    default_error_messages = { +        'invalid': _('`{input}` is not a valid boolean.') +    } +    initial = None +    TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True)) +    FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False)) +    NULL_VALUES = set(('n', 'N', 'null', 'Null', 'NULL', '', None)) + +    def __init__(self, **kwargs): +        assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.' +        kwargs['allow_null'] = True +        super(NullBooleanField, self).__init__(**kwargs)      def to_internal_value(self, data):          if data in self.TRUE_VALUES:              return True          elif data in self.FALSE_VALUES:              return False +        elif data in self.NULL_VALUES: +            return None          self.fail('invalid', input=data)      def to_representation(self, value): -        if value is None: +        if value in self.NULL_VALUES:              return None          if value in self.TRUE_VALUES:              return True @@ -300,136 +492,174 @@ class BooleanField(Field):  class CharField(Field):      default_error_messages = { -        'blank': _('This field may not be blank.') +        'blank': _('This field may not be blank.'), +        'max_length': _('Ensure this field has no more than {max_length} characters.'), +        'min_length': _('Ensure this field has no more than {min_length} characters.')      } +    initial = '' +    coerce_blank_to_null = False +    default_empty_html = ''      def __init__(self, **kwargs):          self.allow_blank = kwargs.pop('allow_blank', False) -        self.max_length = kwargs.pop('max_length', None) -        self.min_length = kwargs.pop('min_length', None) +        max_length = kwargs.pop('max_length', None) +        min_length = kwargs.pop('min_length', None)          super(CharField, self).__init__(**kwargs) +        if max_length is not None: +            message = self.error_messages['max_length'].format(max_length=max_length) +            self.validators.append(MaxLengthValidator(max_length, message=message)) +        if min_length is not None: +            message = self.error_messages['min_length'].format(min_length=min_length) +            self.validators.append(MinLengthValidator(min_length, message=message)) + +    def run_validation(self, data=empty): +        # Test for the empty string here so that it does not get validated, +        # and so that subclasses do not need to handle it explicitly +        # inside the `to_internal_value()` method. +        if data == '': +            if not self.allow_blank: +                self.fail('blank') +            return '' +        return super(CharField, self).run_validation(data)      def to_internal_value(self, data): -        if data == '' and not self.allow_blank: -            self.fail('blank') -        if data is None: -            return None -        return str(data) +        return six.text_type(data)      def to_representation(self, value): -        if value is None: -            return None -        return str(value) +        return six.text_type(value)  class EmailField(CharField):      default_error_messages = {          'invalid': _('Enter a valid email address.')      } -    default_validators = [validators.validate_email] + +    def __init__(self, **kwargs): +        super(EmailField, self).__init__(**kwargs) +        validator = EmailValidator(message=self.error_messages['invalid']) +        self.validators.append(validator)      def to_internal_value(self, data): -        if data == '' and not self.allow_blank: -            self.fail('blank') -        if data is None: -            return None -        return str(data).strip() +        return six.text_type(data).strip()      def to_representation(self, value): -        if value is None: -            return None -        return str(value).strip() +        return six.text_type(value).strip()  class RegexField(CharField): +    default_error_messages = { +        'invalid': _('This value does not match the required pattern.') +    } +      def __init__(self, regex, **kwargs): -        kwargs['validators'] = ( -            [validators.RegexValidator(regex)] + -            kwargs.get('validators', []) -        )          super(RegexField, self).__init__(**kwargs) +        validator = RegexValidator(regex, message=self.error_messages['invalid']) +        self.validators.append(validator)  class SlugField(CharField):      default_error_messages = {          'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")      } -    default_validators = [validators.validate_slug] + +    def __init__(self, **kwargs): +        super(SlugField, self).__init__(**kwargs) +        slug_regex = re.compile(r'^[-a-zA-Z0-9_]+$') +        validator = RegexValidator(slug_regex, message=self.error_messages['invalid']) +        self.validators.append(validator)  class URLField(CharField):      default_error_messages = {          'invalid': _("Enter a valid URL.")      } -    default_validators = [validators.URLValidator()] + +    def __init__(self, **kwargs): +        super(URLField, self).__init__(**kwargs) +        validator = URLValidator(message=self.error_messages['invalid']) +        self.validators.append(validator)  # Number types...  class IntegerField(Field):      default_error_messages = { -        'invalid': _('A valid integer is required.') +        'invalid': _('A valid integer is required.'), +        'max_value': _('Ensure this value is less than or equal to {max_value}.'), +        'min_value': _('Ensure this value is greater than or equal to {min_value}.'), +        'max_string_length': _('String value too large')      } +    MAX_STRING_LENGTH = 1000  # Guard against malicious string inputs.      def __init__(self, **kwargs):          max_value = kwargs.pop('max_value', None)          min_value = kwargs.pop('min_value', None)          super(IntegerField, self).__init__(**kwargs)          if max_value is not None: -            self.validators.append(validators.MaxValueValidator(max_value)) +            message = self.error_messages['max_value'].format(max_value=max_value) +            self.validators.append(MaxValueValidator(max_value, message=message))          if min_value is not None: -            self.validators.append(validators.MinValueValidator(min_value)) +            message = self.error_messages['min_value'].format(min_value=min_value) +            self.validators.append(MinValueValidator(min_value, message=message))      def to_internal_value(self, data): +        if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: +            self.fail('max_string_length') +          try: -            data = int(str(data)) +            data = int(data)          except (ValueError, TypeError):              self.fail('invalid')          return data      def to_representation(self, value): -        if value is None: -            return None          return int(value)  class FloatField(Field):      default_error_messages = { -        'invalid': _("'%s' value must be a float."), +        'invalid': _("A valid number is required."), +        'max_value': _('Ensure this value is less than or equal to {max_value}.'), +        'min_value': _('Ensure this value is greater than or equal to {min_value}.'), +        'max_string_length': _('String value too large')      } +    MAX_STRING_LENGTH = 1000  # Guard against malicious string inputs.      def __init__(self, **kwargs):          max_value = kwargs.pop('max_value', None)          min_value = kwargs.pop('min_value', None)          super(FloatField, self).__init__(**kwargs)          if max_value is not None: -            self.validators.append(validators.MaxValueValidator(max_value)) +            message = self.error_messages['max_value'].format(max_value=max_value) +            self.validators.append(MaxValueValidator(max_value, message=message))          if min_value is not None: -            self.validators.append(validators.MinValueValidator(min_value)) +            message = self.error_messages['min_value'].format(min_value=min_value) +            self.validators.append(MinValueValidator(min_value, message=message)) -    def to_internal_value(self, value): -        if value is None: -            return None -        return float(value) +    def to_internal_value(self, data): +        if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: +            self.fail('max_string_length') -    def to_representation(self, value): -        if value is None: -            return None          try: -            return float(value) +            return float(data)          except (TypeError, ValueError): -            self.fail('invalid', value=value) +            self.fail('invalid') + +    def to_representation(self, value): +        return float(value)  class DecimalField(Field):      default_error_messages = { -        'invalid': _('Enter a number.'), +        'invalid': _('A valid number is required.'),          'max_value': _('Ensure this value is less than or equal to {max_value}.'),          'min_value': _('Ensure this value is greater than or equal to {min_value}.'),          'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),          'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), -        'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.') +        'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), +        'max_string_length': _('String value too large')      } +    MAX_STRING_LENGTH = 1000  # Guard against malicious string inputs.      coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING @@ -439,23 +669,25 @@ class DecimalField(Field):          self.coerce_to_string = coerce_to_string if (coerce_to_string is not None) else self.coerce_to_string          super(DecimalField, self).__init__(**kwargs)          if max_value is not None: -            self.validators.append(validators.MaxValueValidator(max_value)) +            message = self.error_messages['max_value'].format(max_value=max_value) +            self.validators.append(MaxValueValidator(max_value, message=message))          if min_value is not None: -            self.validators.append(validators.MinValueValidator(min_value)) +            message = self.error_messages['min_value'].format(min_value=min_value) +            self.validators.append(MinValueValidator(min_value, message=message)) -    def to_internal_value(self, value): +    def to_internal_value(self, data):          """          Validates that the input is a decimal number. Returns a Decimal          instance. Returns None for empty values. Ensures that there are no more          than max_digits in the number, and no more than decimal_places digits          after the decimal point.          """ -        if value in (None, ''): -            return None +        data = smart_text(data).strip() +        if len(data) > self.MAX_STRING_LENGTH: +            self.fail('max_string_length') -        value = smart_text(value).strip()          try: -            value = decimal.Decimal(value) +            value = decimal.Decimal(data)          except decimal.DecimalException:              self.fail('invalid') @@ -485,125 +717,116 @@ class DecimalField(Field):          if self.decimal_places is not None and decimals > self.decimal_places:              self.fail('max_decimal_places', max_decimal_places=self.decimal_places)          if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): -            self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) +            self.fail('max_whole_digits', max_whole_digits=self.max_digits - self.decimal_places)          return value      def to_representation(self, value): -        if isinstance(value, decimal.Decimal): -            context = decimal.getcontext().copy() -            context.prec = self.max_digits -            quantized = value.quantize( -                decimal.Decimal('.1') ** self.decimal_places, -                context=context -            ) -            if not self.coerce_to_string: -                return quantized -            return '{0:f}'.format(quantized) - +        if not isinstance(value, decimal.Decimal): +            value = decimal.Decimal(six.text_type(value).strip()) + +        context = decimal.getcontext().copy() +        context.prec = self.max_digits +        quantized = value.quantize( +            decimal.Decimal('.1') ** self.decimal_places, +            context=context +        )          if not self.coerce_to_string: -            return value -        return '%.*f' % (self.max_decimal_places, value) +            return quantized +        return '{0:f}'.format(quantized)  # Date & time fields... -class DateField(Field): +class DateTimeField(Field):      default_error_messages = { -        'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), +        'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), +        'date': _('Expected a datetime but got a date.'),      } -    format = api_settings.DATE_FORMAT -    input_formats = api_settings.DATE_INPUT_FORMATS +    format = api_settings.DATETIME_FORMAT +    input_formats = api_settings.DATETIME_INPUT_FORMATS +    default_timezone = timezone.get_default_timezone() if settings.USE_TZ else None -    def __init__(self, format=None, input_formats=None, *args, **kwargs): -        self.format = format if format is not None else self.format +    def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): +        self.format = format if format is not empty else self.format          self.input_formats = input_formats if input_formats is not None else self.input_formats -        super(DateField, self).__init__(*args, **kwargs) +        self.default_timezone = default_timezone if default_timezone is not None else self.default_timezone +        super(DateTimeField, self).__init__(*args, **kwargs) + +    def enforce_timezone(self, value): +        """ +        When `self.default_timezone` is `None`, always return naive datetimes. +        When `self.default_timezone` is not `None`, always return aware datetimes. +        """ +        if (self.default_timezone is not None) and not timezone.is_aware(value): +            return timezone.make_aware(value, self.default_timezone) +        elif (self.default_timezone is None) and timezone.is_aware(value): +            return timezone.make_naive(value, timezone.UTC()) +        return value      def to_internal_value(self, value): -        if value in (None, ''): -            return None +        if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): +            self.fail('date')          if isinstance(value, datetime.datetime): -            if timezone and settings.USE_TZ and timezone.is_aware(value): -                # Convert aware datetimes to the default time zone -                # before casting them to dates (#17742). -                default_timezone = timezone.get_default_timezone() -                value = timezone.make_naive(value, default_timezone) -            return value.date() - -        if isinstance(value, datetime.date): -            return value +            return self.enforce_timezone(value)          for format in self.input_formats:              if format.lower() == ISO_8601:                  try: -                    parsed = parse_date(value) +                    parsed = parse_datetime(value)                  except (ValueError, TypeError):                      pass                  else:                      if parsed is not None: -                        return parsed +                        return self.enforce_timezone(parsed)              else:                  try:                      parsed = datetime.datetime.strptime(value, format)                  except (ValueError, TypeError):                      pass                  else: -                    return parsed.date() +                    return self.enforce_timezone(parsed) -        humanized_format = humanize_datetime.date_formats(self.input_formats) +        humanized_format = humanize_datetime.datetime_formats(self.input_formats)          self.fail('invalid', format=humanized_format)      def to_representation(self, value): -        if value is None or self.format is None: +        if self.format is None:              return value -        if isinstance(value, datetime.datetime): -            value = value.date() -          if self.format.lower() == ISO_8601: -            return value.isoformat() +            value = value.isoformat() +            if value.endswith('+00:00'): +                value = value[:-6] + 'Z' +            return value          return value.strftime(self.format) -class DateTimeField(Field): +class DateField(Field):      default_error_messages = { -        'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}'), +        'invalid': _('Date has wrong format. Use one of these formats instead: {format}'), +        'datetime': _('Expected a date but got a datetime.'),      } -    format = api_settings.DATETIME_FORMAT -    input_formats = api_settings.DATETIME_INPUT_FORMATS +    format = api_settings.DATE_FORMAT +    input_formats = api_settings.DATE_INPUT_FORMATS -    def __init__(self, format=None, input_formats=None, *args, **kwargs): -        self.format = format if format is not None else self.format +    def __init__(self, format=empty, input_formats=None, *args, **kwargs): +        self.format = format if format is not empty else self.format          self.input_formats = input_formats if input_formats is not None else self.input_formats -        super(DateTimeField, self).__init__(*args, **kwargs) +        super(DateField, self).__init__(*args, **kwargs)      def to_internal_value(self, value): -        if value in (None, ''): -            return None -          if isinstance(value, datetime.datetime): -            return value +            self.fail('datetime')          if isinstance(value, datetime.date): -            value = datetime.datetime(value.year, value.month, value.day) -            if settings.USE_TZ: -                # For backwards compatibility, interpret naive datetimes in -                # local time. This won't work during DST change, but we can't -                # do much about it, so we let the exceptions percolate up the -                # call stack. -                warnings.warn("DateTimeField received a naive datetime (%s)" -                              " while time zone support is active." % value, -                              RuntimeWarning) -                default_timezone = timezone.get_default_timezone() -                value = timezone.make_aware(value, default_timezone)              return value          for format in self.input_formats:              if format.lower() == ISO_8601:                  try: -                    parsed = parse_datetime(value) +                    parsed = parse_date(value)                  except (ValueError, TypeError):                      pass                  else: @@ -615,20 +838,26 @@ class DateTimeField(Field):                  except (ValueError, TypeError):                      pass                  else: -                    return parsed +                    return parsed.date() -        humanized_format = humanize_datetime.datetime_formats(self.input_formats) +        humanized_format = humanize_datetime.date_formats(self.input_formats)          self.fail('invalid', format=humanized_format)      def to_representation(self, value): -        if value is None or self.format is None: +        if self.format is None:              return value +        # Applying a `DateField` to a datetime value is almost always +        # not a sensible thing to do, as it means naively dropping +        # any explicit or implicit timezone info. +        assert not isinstance(value, datetime.datetime), ( +            'Expected a `date`, but got a `datetime`. Refusing to coerce, ' +            'as this may mean losing timezone information. Use a custom ' +            'read-only field and deal with timezone issues explicitly.' +        ) +          if self.format.lower() == ISO_8601: -            ret = value.isoformat() -            if ret.endswith('+00:00'): -                ret = ret[:-6] + 'Z' -            return ret +            return value.isoformat()          return value.strftime(self.format) @@ -639,15 +868,12 @@ class TimeField(Field):      format = api_settings.TIME_FORMAT      input_formats = api_settings.TIME_INPUT_FORMATS -    def __init__(self, format=None, input_formats=None, *args, **kwargs): -        self.format = format if format is not None else self.format +    def __init__(self, format=empty, input_formats=None, *args, **kwargs): +        self.format = format if format is not empty else self.format          self.input_formats = input_formats if input_formats is not None else self.input_formats          super(TimeField, self).__init__(*args, **kwargs) -    def from_native(self, value): -        if value in (None, ''): -            return None - +    def to_internal_value(self, value):          if isinstance(value, datetime.time):              return value @@ -672,11 +898,17 @@ class TimeField(Field):          self.fail('invalid', format=humanized_format)      def to_representation(self, value): -        if value is None or self.format is None: +        if self.format is None:              return value -        if isinstance(value, datetime.datetime): -            value = value.time() +        # Applying a `TimeField` to a datetime value is almost always +        # not a sensible thing to do, as it means naively dropping +        # any explicit or implicit timezone info. +        assert not isinstance(value, datetime.datetime), ( +            'Expected a `time`, but got a `datetime`. Refusing to coerce, ' +            'as this may mean losing timezone information. Use a custom ' +            'read-only field and deal with timezone issues explicitly.' +        )          if self.format.lower() == ISO_8601:              return value.isoformat() @@ -699,58 +931,171 @@ class ChoiceField(Field):              for item in choices          ]          if all(pairs): -            self.choices = dict([(key, display_value) for key, display_value in choices]) +            self.choices = OrderedDict([(key, display_value) for key, display_value in choices])          else: -            self.choices = dict([(item, item) for item in choices]) +            self.choices = OrderedDict([(item, item) for item in choices])          # Map the string representation of choices to the underlying value.          # Allows us to deal with eg. integer choices while supporting either          # integer or string input, but still get the correct datatype out.          self.choice_strings_to_values = dict([ -            (str(key), key) for key in self.choices.keys() +            (six.text_type(key), key) for key in self.choices.keys()          ])          super(ChoiceField, self).__init__(**kwargs)      def to_internal_value(self, data):          try: -            return self.choice_strings_to_values[str(data)] +            return self.choice_strings_to_values[six.text_type(data)]          except KeyError:              self.fail('invalid_choice', input=data)      def to_representation(self, value): -        return value +        if value in ('', None): +            return value +        return self.choice_strings_to_values[six.text_type(value)]  class MultipleChoiceField(ChoiceField):      default_error_messages = {          'invalid_choice': _('`{input}` is not a valid choice.'), -        'not_a_list': _('Expected a list of items but got type `{input_type}`') +        'not_a_list': _('Expected a list of items but got type `{input_type}`.')      } +    default_empty_html = [] + +    def get_value(self, dictionary): +        # We override the default field access in order to support +        # lists in HTML forms. +        if html.is_html_input(dictionary): +            return dictionary.getlist(self.field_name) +        return dictionary.get(self.field_name, empty)      def to_internal_value(self, data): -        if not hasattr(data, '__iter__'): +        if isinstance(data, type('')) or not hasattr(data, '__iter__'):              self.fail('not_a_list', input_type=type(data).__name__) +          return set([              super(MultipleChoiceField, self).to_internal_value(item)              for item in data          ])      def to_representation(self, value): -        return value +        return set([ +            self.choice_strings_to_values[six.text_type(item)] for item in value +        ])  # File types...  class FileField(Field): -    pass  # TODO +    default_error_messages = { +        'required': _("No file was submitted."), +        'invalid': _("The submitted data was not a file. Check the encoding type on the form."), +        'no_name': _("No filename could be determined."), +        'empty': _("The submitted file is empty."), +        'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'), +    } +    use_url = api_settings.UPLOADED_FILES_USE_URL + +    def __init__(self, *args, **kwargs): +        self.max_length = kwargs.pop('max_length', None) +        self.allow_empty_file = kwargs.pop('allow_empty_file', False) +        self.use_url = kwargs.pop('use_url', self.use_url) +        super(FileField, self).__init__(*args, **kwargs) + +    def to_internal_value(self, data): +        try: +            # `UploadedFile` objects should have name and size attributes. +            file_name = data.name +            file_size = data.size +        except AttributeError: +            self.fail('invalid') + +        if not file_name: +            self.fail('no_name') +        if not self.allow_empty_file and not file_size: +            self.fail('empty') +        if self.max_length and len(file_name) > self.max_length: +            self.fail('max_length', max_length=self.max_length, length=len(file_name)) + +        return data + +    def to_representation(self, value): +        if self.use_url: +            if not value: +                return None +            url = value.url +            request = self.context.get('request', None) +            if request is not None: +                return request.build_absolute_uri(url) +            return url +        return value.name + + +class ImageField(FileField): +    default_error_messages = { +        'invalid_image': _( +            'Upload a valid image. The file you uploaded was either not an ' +            'image or a corrupted image.' +        ), +    } + +    def __init__(self, *args, **kwargs): +        self._DjangoImageField = kwargs.pop('_DjangoImageField', DjangoImageField) +        super(ImageField, self).__init__(*args, **kwargs) + +    def to_internal_value(self, data): +        # Image validation is a bit grungy, so we'll just outright +        # defer to Django's implementation so we don't need to +        # consider it, or treat PIL as a test dependency. +        file_object = super(ImageField, self).to_internal_value(data) +        django_field = self._DjangoImageField() +        django_field.error_messages = self.error_messages +        django_field.to_python(file_object) +        return file_object + + +# Composite field types... + +class ListField(Field): +    child = None +    initial = [] +    default_error_messages = { +        'not_a_list': _('Expected a list of items but got type `{input_type}`') +    } + +    def __init__(self, *args, **kwargs): +        self.child = kwargs.pop('child', copy.deepcopy(self.child)) +        assert self.child is not None, '`child` is a required argument.' +        assert not inspect.isclass(self.child), '`child` has not been instantiated.' +        super(ListField, self).__init__(*args, **kwargs) +        self.child.bind(field_name='', parent=self) + +    def get_value(self, dictionary): +        # We override the default field access in order to support +        # lists in HTML forms. +        if html.is_html_input(dictionary): +            return html.parse_html_list(dictionary, prefix=self.field_name) +        return dictionary.get(self.field_name, empty) +    def to_internal_value(self, data): +        """ +        List of dicts of native values <- List of dicts of primitive datatypes. +        """ +        if html.is_html_input(data): +            data = html.parse_html_list(data) +        if isinstance(data, type('')) or not hasattr(data, '__iter__'): +            self.fail('not_a_list', input_type=type(data).__name__) +        return [self.child.run_validation(item) for item in data] -class ImageField(Field): -    pass  # TODO +    def to_representation(self, data): +        """ +        List of object instances -> List of dicts of primitive datatypes. +        """ +        return [self.child.to_representation(item) for item in data] -# Advanced field types... +# Miscellaneous field types...  class ReadOnlyField(Field):      """ @@ -770,11 +1115,31 @@ class ReadOnlyField(Field):          super(ReadOnlyField, self).__init__(**kwargs)      def to_representation(self, value): -        if is_simple_callable(value): -            return value()          return value +class HiddenField(Field): +    """ +    A hidden field does not take input from the user, or present any output, +    but it does populate a field in `validated_data`, based on its default +    value. This is particularly useful when we have a `unique_for_date` +    constraint on a pair of fields, as we need some way to include the date in +    the validated data. +    """ +    def __init__(self, **kwargs): +        assert 'default' in kwargs, 'default is a required argument.' +        kwargs['write_only'] = True +        super(HiddenField, self).__init__(**kwargs) + +    def get_value(self, dictionary): +        # We always use the default value for `HiddenField`. +        # User input is never provided or accepted. +        return empty + +    def to_internal_value(self, data): +        return data + +  class SerializerMethodField(Field):      """      A read-only field that get its representation from calling a method on the @@ -790,17 +1155,32 @@ class SerializerMethodField(Field):          def get_extra_info(self, obj):              return ...  # Calculate some data to return.      """ -    def __init__(self, method_attr=None, **kwargs): -        self.method_attr = method_attr +    def __init__(self, method_name=None, **kwargs): +        self.method_name = method_name          kwargs['source'] = '*'          kwargs['read_only'] = True          super(SerializerMethodField, self).__init__(**kwargs) +    def bind(self, field_name, parent): +        # In order to enforce a consistent style, we error if a redundant +        # 'method_name' argument has been used. For example: +        # my_field = serializer.CharField(source='my_field') +        default_method_name = 'get_{field_name}'.format(field_name=field_name) +        assert self.method_name != default_method_name, ( +            "It is redundant to specify `%s` on SerializerMethodField '%s' in " +            "serializer '%s', because it is the same as the default method name. " +            "Remove the `method_name` argument." % +            (self.method_name, field_name, parent.__class__.__name__) +        ) + +        # The method name should default to `get_{field_name}`. +        if self.method_name is None: +            self.method_name = default_method_name + +        super(SerializerMethodField, self).bind(field_name, parent) +      def to_representation(self, value): -        method_attr = self.method_attr -        if method_attr is None: -            method_attr = 'get_{field_name}'.format(field_name=self.field_name) -        method = getattr(self.parent, method_attr) +        method = getattr(self.parent, self.method_name)          return method(value) @@ -811,10 +1191,19 @@ class ModelField(Field):      This is used by `ModelSerializer` when dealing with custom model fields,      that do not have a serializer field to be mapped to.      """ +    default_error_messages = { +        'max_length': _('Ensure this field has no more than {max_length} characters.'), +    } +      def __init__(self, model_field, **kwargs):          self.model_field = model_field -        kwargs['source'] = '*' +        # The `max_length` option is supported by Django's base `Field` class, +        # so we'd better support it here. +        max_length = kwargs.pop('max_length', None)          super(ModelField, self).__init__(**kwargs) +        if max_length is not None: +            message = self.error_messages['max_length'].format(max_length=max_length) +            self.validators.append(MaxLengthValidator(max_length, message=message))      def to_internal_value(self, data):          rel = getattr(self.model_field, 'rel', None) @@ -822,6 +1211,11 @@ class ModelField(Field):              return rel.to._meta.get_field(rel.field_name).to_python(data)          return self.model_field.to_python(data) +    def get_attribute(self, obj): +        # We pass the object instance onto `to_representation`, +        # not just the field attribute. +        return obj +      def to_representation(self, obj):          value = self.model_field._get_val_from_obj(obj)          if is_protected_type(value): | 
