diff options
Diffstat (limited to 'rest_framework/fields.py')
| -rw-r--r-- | rest_framework/fields.py | 296 |
1 files changed, 240 insertions, 56 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 86c3a837..c83ee5ec 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,7 +1,13 @@ +""" +Serializer fields perform validation on incoming data. + +They are very similar to Django's form fields. +""" from __future__ import unicode_literals import copy import datetime +from decimal import Decimal, DecimalException import inspect import re import warnings @@ -13,26 +19,29 @@ from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ -from rest_framework.compat import parse_date, parse_datetime -from rest_framework.compat import timezone + +from rest_framework import ISO_8601 +from rest_framework.compat import timezone, parse_date, parse_datetime, parse_time from rest_framework.compat import BytesIO from rest_framework.compat import six from rest_framework.compat import smart_text -from rest_framework.compat import parse_time +from rest_framework.settings import api_settings def is_simple_callable(obj): """ True if the object is a callable that takes no arguments. """ - try: - args, _, _, defaults = inspect.getargspec(obj) - except TypeError: + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): return False - else: - len_args = len(args) if inspect.isfunction(obj) else len(args) - 1 - len_defaults = len(defaults) if defaults else 0 - return len_args <= len_defaults + + args, _, _, defaults = inspect.getargspec(obj) + len_args = len(args) if function else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults def get_component(obj, attr_name): @@ -50,6 +59,46 @@ def get_component(obj, attr_name): return val +def readable_datetime_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HHMM|-HHMM|Z]') + return humanize_strptime(format) + + +def readable_date_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') + return humanize_strptime(format) + + +def readable_time_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') + return humanize_strptime(format) + + +def humanize_strptime(format_string): + # Note that we're missing some of the locale specific mappings that + # don't really make sense. + mapping = { + "%Y": "YYYY", + "%y": "YY", + "%m": "MM", + "%b": "[Jan-Dec]", + "%B": "[January-December]", + "%d": "DD", + "%H": "hh", + "%I": "hh", # Requires '%p' to differentiate from '%H'. + "%M": "mm", + "%S": "ss", + "%f": "uuuuuu", + "%a": "[Mon-Sun]", + "%A": "[Monday-Sunday]", + "%p": "[AM|PM]", + "%z": "[+HHMM|-HHMM]" + } + for key, val in mapping.items(): + format_string = format_string.replace(key, val) + return format_string + + class Field(object): read_only = True creation_counter = 0 @@ -151,9 +200,9 @@ class WritableField(Field): # 'blank' is to be deprecated in favor of 'required' if blank is not None: - warnings.warn('The `blank` keyword argument is due to deprecated. ' + warnings.warn('The `blank` keyword argument is deprecated. ' 'Use the `required` keyword argument instead.', - PendingDeprecationWarning, stacklevel=2) + DeprecationWarning, stacklevel=2) required = not(blank) super(WritableField, self).__init__(source=source) @@ -447,12 +496,16 @@ class DateField(WritableField): form_field_class = forms.DateField default_error_messages = { - 'invalid': _("'%s' value has an invalid date format. It must be " - "in YYYY-MM-DD format."), - 'invalid_date': _("'%s' value has the correct format (YYYY-MM-DD) " - "but it is an invalid date."), + 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), } empty = None + input_formats = api_settings.DATE_INPUT_FORMATS + format = api_settings.DATE_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateField, self).__init__(*args, **kwargs) def from_native(self, value): if value in validators.EMPTY_VALUES: @@ -468,17 +521,37 @@ class DateField(WritableField): if isinstance(value, datetime.date): return value - try: - parsed = parse_date(value) - if parsed is not None: - return parsed - except (ValueError, TypeError): - msg = self.error_messages['invalid_date'] % value - raise ValidationError(msg) + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_date(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.date() - msg = self.error_messages['invalid'] % value + msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats) raise ValidationError(msg) + def to_native(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.date() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + class DateTimeField(WritableField): type_name = 'DateTimeField' @@ -486,15 +559,16 @@ class DateTimeField(WritableField): form_field_class = forms.DateTimeField default_error_messages = { - 'invalid': _("'%s' value has an invalid format. It must be in " - "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), - 'invalid_date': _("'%s' value has the correct format " - "(YYYY-MM-DD) but it is an invalid date."), - 'invalid_datetime': _("'%s' value has the correct format " - "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) " - "but it is an invalid date/time."), + 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), } empty = None + input_formats = api_settings.DATETIME_INPUT_FORMATS + format = api_settings.DATETIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateTimeField, self).__init__(*args, **kwargs) def from_native(self, value): if value in validators.EMPTY_VALUES: @@ -516,25 +590,37 @@ class DateTimeField(WritableField): value = timezone.make_aware(value, default_timezone) return value - try: - parsed = parse_datetime(value) - if parsed is not None: - return parsed - except (ValueError, TypeError): - msg = self.error_messages['invalid_datetime'] % value - raise ValidationError(msg) - - try: - parsed = parse_date(value) - if parsed is not None: - return datetime.datetime(parsed.year, parsed.month, parsed.day) - except (ValueError, TypeError): - msg = self.error_messages['invalid_date'] % value - raise ValidationError(msg) + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_datetime(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed - msg = self.error_messages['invalid'] % value + msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats) raise ValidationError(msg) + def to_native(self, value): + if value is None or self.format is None: + return value + + if self.format.lower() == ISO_8601: + ret = value.isoformat() + if ret.endswith('+00:00'): + ret = ret[:-6] + 'Z' + return ret + return value.strftime(self.format) + class TimeField(WritableField): type_name = 'TimeField' @@ -542,10 +628,16 @@ class TimeField(WritableField): form_field_class = forms.TimeField default_error_messages = { - 'invalid': _("'%s' value has an invalid format. It must be a valid " - "time in the HH:MM[:ss[.uuuuuu]] format."), + 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), } empty = None + input_formats = api_settings.TIME_INPUT_FORMATS + format = api_settings.TIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(TimeField, self).__init__(*args, **kwargs) def from_native(self, value): if value in validators.EMPTY_VALUES: @@ -554,13 +646,36 @@ class TimeField(WritableField): if isinstance(value, datetime.time): return value - try: - parsed = parse_time(value) - assert parsed is not None - return parsed - except (ValueError, TypeError): - msg = self.error_messages['invalid'] % value - raise ValidationError(msg) + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_time(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.time() + + msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats) + raise ValidationError(msg) + + def to_native(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.time() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) class IntegerField(WritableField): @@ -612,6 +727,75 @@ class FloatField(WritableField): raise ValidationError(msg) +class DecimalField(WritableField): + type_name = 'DecimalField' + form_field_class = forms.DecimalField + + default_error_messages = { + 'invalid': _('Enter a number.'), + 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), + 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), + 'max_digits': _('Ensure that there are no more than %s digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') + } + + def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): + self.max_value, self.min_value = max_value, min_value + self.max_digits, self.decimal_places = max_digits, decimal_places + super(DecimalField, self).__init__(*args, **kwargs) + + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def from_native(self, value): + """ + Validates that the input is a decimal number. Returns a Decimal + instance. Returns None for empty values. Ensures that there are no more + than max_digits in the number, and no more than decimal_places digits + after the decimal point. + """ + if value in validators.EMPTY_VALUES: + return None + value = smart_text(value).strip() + try: + value = Decimal(value) + except DecimalException: + raise ValidationError(self.error_messages['invalid']) + return value + + def validate(self, value): + super(DecimalField, self).validate(value) + if value in validators.EMPTY_VALUES: + return + # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, + # since it is never equal to itself. However, NaN is the only value that + # isn't equal to itself, so we can use this to identify NaN + if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): + raise ValidationError(self.error_messages['invalid']) + sign, digittuple, exponent = value.as_tuple() + decimals = abs(exponent) + # digittuple doesn't include any leading zeros. + digits = len(digittuple) + if decimals > digits: + # We have leading zeros up to or past the decimal point. Count + # everything past the decimal point as a digit. We do not count + # 0 before the decimal point as a digit since that would mean + # we would not allow max_digits = decimal_places. + digits = decimals + whole_digits = digits - decimals + + if self.max_digits is not None and digits > self.max_digits: + raise ValidationError(self.error_messages['max_digits'] % self.max_digits) + if self.decimal_places is not None and decimals > self.decimal_places: + raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) + if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): + raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) + return value + + class FileField(WritableField): use_files = True type_name = 'FileField' |
