diff options
| author | Tom Christie | 2014-09-09 17:46:28 +0100 |
|---|---|---|
| committer | Tom Christie | 2014-09-09 17:46:28 +0100 |
| commit | b1c07670ca65084c5fef2bbb63d1f4163763014b (patch) | |
| tree | 4f08654d698990d97fe275d8dbbbcc1164524086 /rest_framework | |
| parent | 21980b800d04a1d82a6003823abfdf4ab80ae979 (diff) | |
| download | django-rest-framework-b1c07670ca65084c5fef2bbb63d1f4163763014b.tar.bz2 | |
Fleshing out serializer fields
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/fields.py | 591 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 380 | ||||
| -rw-r--r-- | rest_framework/utils/humanize_datetime.py | 47 | ||||
| -rw-r--r-- | rest_framework/utils/modelinfo.py | 97 | ||||
| -rw-r--r-- | rest_framework/utils/representation.py | 72 |
5 files changed, 872 insertions, 315 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 250c0579..043a44ed 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,8 +1,18 @@ +from django.conf import settings from django.core import validators from django.core.exceptions import ValidationError +from django.utils import timezone +from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type -from rest_framework.utils import html +from django.utils.translation import ugettext_lazy as _ +from rest_framework import ISO_8601 +from rest_framework.compat import smart_text +from rest_framework.settings import api_settings +from rest_framework.utils import html, representation, humanize_datetime +import datetime +import decimal import inspect +import warnings class empty: @@ -71,22 +81,22 @@ class SkipField(Exception): pass +NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' +NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' +NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' +NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' +MISSING_ERROR_MESSAGE = ( + 'ValidationError raised by `{class_name}`, but error key `{key}` does ' + 'not exist in the `error_messages` dictionary.' +) + + class Field(object): _creation_counter = 0 - MESSAGES = { - 'required': 'This field is required.' + default_error_messages = { + 'required': _('This field is required.') } - - _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' - _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' - _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' - _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' - _MISSING_ERROR_MESSAGE = ( - 'ValidationError raised by `{class_name}`, but error key `{key}` does ' - 'not exist in the `MESSAGES` dictionary.' - ) - default_validators = [] def __init__(self, read_only=False, write_only=False, @@ -100,10 +110,10 @@ class Field(object): required = default is empty and not read_only # Some combinations of keyword arguments do not make sense. - assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY - assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED - assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT - assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT + assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY + assert not (read_only and required), NOT_READ_ONLY_REQUIRED + assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT + assert not (required and default is not empty), NOT_REQUIRED_DEFAULT self.read_only = read_only self.write_only = write_only @@ -113,7 +123,14 @@ class Field(object): self.initial = initial self.label = label self.style = {} if style is None else style - self.validators = self.default_validators + validators + self.validators = validators or self.default_validators[:] + + # Collect default error message from self and parent classes + messages = {} + for cls in reversed(self.__class__.__mro__): + messages.update(getattr(cls, 'default_error_messages', {})) + messages.update(error_messages or {}) + self.error_messages = messages def bind(self, field_name, parent, root): """ @@ -186,12 +203,14 @@ class Field(object): self.fail('required') return self.get_default() - self.run_validators(data) - return self.to_native(data) + value = self.to_native(data) + self.run_validators(value) + return value def run_validators(self, value): if value in validators.EMPTY_VALUES: return + errors = [] for validator in self.validators: try: @@ -218,33 +237,32 @@ class Field(object): A helper method that simply raises a validation error. """ try: - raise ValidationError(self.MESSAGES[key].format(**kwargs)) + msg = self.error_messages[key] except KeyError: class_name = self.__class__.__name__ - msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) + msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) + raise ValidationError(msg.format(**kwargs)) def __new__(cls, *args, **kwargs): + """ + When a field is instantiated, we store the arguments that were used, + so that we can present a helpful representation of the object. + """ instance = super(Field, cls).__new__(cls) instance._args = args instance._kwargs = kwargs return instance def __repr__(self): - arg_string = ', '.join([repr(val) for val in self._args]) - kwarg_string = ', '.join([ - '%s=%s' % (key, repr(val)) for key, val in self._kwargs.items() - ]) - if arg_string and kwarg_string: - arg_string += ', ' - class_name = self.__class__.__name__ - return "%s(%s%s)" % (class_name, arg_string, kwarg_string) + return representation.field_repr(self) +# Boolean types... + class BooleanField(Field): - MESSAGES = { - 'required': 'This field is required.', - 'invalid_value': '`{input}` is not a valid boolean.' + default_error_messages = { + 'invalid': _('`{input}` is not a valid boolean.') } TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True} FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False} @@ -261,13 +279,23 @@ class BooleanField(Field): return True elif data in self.FALSE_VALUES: return False - self.fail('invalid_value', input=data) + self.fail('invalid', input=data) + + def to_primative(self, value): + if value is None: + return None + if value in self.TRUE_VALUES: + return True + elif value in self.FALSE_VALUES: + return False + return bool(value) +# String types... + class CharField(Field): - MESSAGES = { - 'required': 'This field is required.', - 'blank': 'This field may not be blank.' + default_error_messages = { + 'blank': _('This field may not be blank.') } def __init__(self, **kwargs): @@ -281,19 +309,364 @@ class CharField(Field): self.fail('blank') return str(data) + def to_primative(self, value): + if value is None: + return None + return str(value) -class ChoiceField(Field): - MESSAGES = { - 'required': 'This field is required.', - 'invalid_choice': '`{input}` is not a valid choice.' + +class EmailField(CharField): + default_error_messages = { + 'invalid': _('Enter a valid email address.') + } + default_validators = [validators.validate_email] + + def to_native(self, data): + ret = super(EmailField, self).to_native(data) + if ret is None: + return None + return ret.strip() + + def to_primative(self, value): + ret = super(EmailField, self).to_primative(value) + if ret is None: + return None + return ret.strip() + + +class RegexField(CharField): + def __init__(self, regex, **kwargs): + kwargs['validators'] = ( + [validators.RegexValidator(regex)] + + kwargs.get('validators', []) + ) + super(RegexField, self).__init__(**kwargs) + + +class SlugField(CharField): + default_error_messages = { + 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.") + } + default_validators = [validators.validate_slug] + + +class URLField(CharField): + default_error_messages = { + 'invalid': _("Enter a valid URL.") + } + default_validators = [validators.URLValidator()] + + +# Number types... + +class IntegerField(Field): + default_error_messages = { + 'invalid': _('A valid integer is required.') + } + + def __init__(self, **kwargs): + max_value = kwargs.pop('max_value', None) + min_value = kwargs.pop('min_value', None) + super(IntegerField, self).__init__(**kwargs) + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + print self.__class__.__name__, self.validators + + def to_native(self, data): + try: + data = int(str(data)) + except (ValueError, TypeError): + self.fail('invalid') + return data + + def to_primative(self, value): + if value is None: + return None + return int(value) + + +class FloatField(Field): + default_error_messages = { + 'invalid': _("'%s' value must be a float."), } - coerce_to_type = str def __init__(self, **kwargs): - choices = kwargs.pop('choices') + max_value = kwargs.pop('max_value', None) + min_value = kwargs.pop('min_value', None) + super(FloatField, self).__init__(**kwargs) + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def to_primative(self, value): + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + self.fail('invalid', value=value) + + def to_native(self, value): + if value is None: + return None + return float(value) + + +class DecimalField(Field): + default_error_messages = { + 'invalid': _('Enter a number.'), + 'max_value': _('Ensure this value is less than or equal to {max_value}.'), + 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), + 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), + 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), + 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.') + } + + def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs): + self.max_value, self.min_value = max_value, min_value + self.max_digits, self.max_decimal_places = max_digits, decimal_places + super(DecimalField, self).__init__(**kwargs) + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) + + def from_native(self, value): + """ + Validates that the input is a decimal number. Returns a Decimal + instance. Returns None for empty values. Ensures that there are no more + than max_digits in the number, and no more than decimal_places digits + after the decimal point. + """ + if value in validators.EMPTY_VALUES: + return None + + value = smart_text(value).strip() + try: + value = decimal.Decimal(value) + except decimal.DecimalException: + self.fail('invalid') + + # Check for NaN. It is the only value that isn't equal to itself, + # so we can use this to identify NaN values. + if value != value: + self.fail('invalid') + + # Check for infinity and negative infinity. + if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): + self.fail('invalid') + + sign, digittuple, exponent = value.as_tuple() + decimals = abs(exponent) + # digittuple doesn't include any leading zeros. + digits = len(digittuple) + if decimals > digits: + # We have leading zeros up to or past the decimal point. Count + # everything past the decimal point as a digit. We do not count + # 0 before the decimal point as a digit since that would mean + # we would not allow max_digits = decimal_places. + digits = decimals + whole_digits = digits - decimals + + if self.max_digits is not None and digits > self.max_digits: + self.fail('max_digits', max_digits=self.max_digits) + if self.decimal_places is not None and decimals > self.decimal_places: + self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places) + if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): + self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) + + return value + + +# Date & time fields... + +class DateField(Field): + default_error_messages = { + 'invalid': _("Date has wrong format. Use one of these formats instead: %s"), + } + input_formats = api_settings.DATE_INPUT_FORMATS + format = api_settings.DATE_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateField, self).__init__(*args, **kwargs) + + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + + if isinstance(value, datetime.datetime): + if timezone and settings.USE_TZ and timezone.is_aware(value): + # Convert aware datetimes to the default time zone + # before casting them to dates (#17742). + default_timezone = timezone.get_default_timezone() + value = timezone.make_naive(value, default_timezone) + return value.date() + if isinstance(value, datetime.date): + return value + + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_date(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.date() + + humanized_format = humanize_datetime.date_formats(self.input_formats) + msg = self.error_messages['invalid'] % humanized_format + raise ValidationError(msg) + + def to_primative(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.date() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + + +class DateTimeField(Field): + default_error_messages = { + 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), + } + input_formats = api_settings.DATETIME_INPUT_FORMATS + format = api_settings.DATETIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(DateTimeField, self).__init__(*args, **kwargs) + + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + value = datetime.datetime(value.year, value.month, value.day) + if settings.USE_TZ: + # For backwards compatibility, interpret naive datetimes in + # local time. This won't work during DST change, but we can't + # do much about it, so we let the exceptions percolate up the + # call stack. + warnings.warn("DateTimeField received a naive datetime (%s)" + " while time zone support is active." % value, + RuntimeWarning) + default_timezone = timezone.get_default_timezone() + value = timezone.make_aware(value, default_timezone) + return value + + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_datetime(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed + + humanized_format = humanize_datetime.datetime_formats(self.input_formats) + msg = self.error_messages['invalid'] % humanized_format + raise ValidationError(msg) + + def to_primative(self, value): + if value is None or self.format is None: + return value + + if self.format.lower() == ISO_8601: + ret = value.isoformat() + if ret.endswith('+00:00'): + ret = ret[:-6] + 'Z' + return ret + return value.strftime(self.format) - assert choices, '`choices` argument is required and may not be empty' +class TimeField(Field): + default_error_messages = { + 'invalid': _("Time has wrong format. Use one of these formats instead: %s"), + } + input_formats = api_settings.TIME_INPUT_FORMATS + format = api_settings.TIME_FORMAT + + def __init__(self, input_formats=None, format=None, *args, **kwargs): + self.input_formats = input_formats if input_formats is not None else self.input_formats + self.format = format if format is not None else self.format + super(TimeField, self).__init__(*args, **kwargs) + + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + + if isinstance(value, datetime.time): + return value + + for format in self.input_formats: + if format.lower() == ISO_8601: + try: + parsed = parse_time(value) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = datetime.datetime.strptime(value, format) + except (ValueError, TypeError): + pass + else: + return parsed.time() + + humanized_format = humanize_datetime.time_formats(self.input_formats) + msg = self.error_messages['invalid'] % humanized_format + raise ValidationError(msg) + + def to_primative(self, value): + if value is None or self.format is None: + return value + + if isinstance(value, datetime.datetime): + value = value.time() + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + + +# Choice types... + +class ChoiceField(Field): + default_error_messages = { + 'invalid_choice': _('`{input}` is not a valid choice.') + } + + def __init__(self, choices, **kwargs): # Allow either single or paired choices style: # choices = [1, 2, 3] # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] @@ -321,12 +694,14 @@ class ChoiceField(Field): except KeyError: self.fail('invalid_choice', input=data) + def to_primative(self, value): + return value + class MultipleChoiceField(ChoiceField): - MESSAGES = { - 'required': 'This field is required.', - 'invalid_choice': '`{input}` is not a valid choice.', - 'not_a_list': 'Expected a list of items but got type `{input_type}`' + default_error_messages = { + 'invalid_choice': _('`{input}` is not a valid choice.'), + 'not_a_list': _('Expected a list of items but got type `{input_type}`') } def to_native(self, data): @@ -337,72 +712,42 @@ class MultipleChoiceField(ChoiceField): for item in data ]) - -class IntegerField(Field): - MESSAGES = { - 'required': 'This field is required.', - 'invalid_integer': 'A valid integer is required.' - } - - def __init__(self, **kwargs): - max_value = kwargs.pop('max_value', None) - min_value = kwargs.pop('min_value', None) - super(IntegerField, self).__init__(**kwargs) - if max_value is not None: - self.validators.append(validators.MaxValueValidator(max_value)) - if min_value is not None: - self.validators.append(validators.MinValueValidator(min_value)) - - def to_native(self, data): - try: - data = int(str(data)) - except (ValueError, TypeError): - self.fail('invalid_integer') - return data - def to_primative(self, value): - if value is None: - return None - return int(value) + return value -class EmailField(CharField): +# File types... + +class FileField(Field): pass # TODO -class URLField(CharField): +class ImageField(Field): pass # TODO -class RegexField(CharField): - def __init__(self, **kwargs): - self.regex = kwargs.pop('regex') - super(CharField, self).__init__(**kwargs) - +# Advanced field types... -class DateField(CharField): - def __init__(self, **kwargs): - self.input_formats = kwargs.pop('input_formats', None) - super(DateField, self).__init__(**kwargs) +class ReadOnlyField(Field): + """ + A read-only field that simply returns the field value. + If the field is a method with no parameters, the method will be called + and it's return value used as the representation. -class TimeField(CharField): - def __init__(self, **kwargs): - self.input_formats = kwargs.pop('input_formats', None) - super(TimeField, self).__init__(**kwargs) + For example, the following would call `get_expiry_date()` on the object: + class ExampleSerializer(self): + expiry_date = ReadOnlyField(source='get_expiry_date') + """ -class DateTimeField(CharField): def __init__(self, **kwargs): - self.input_formats = kwargs.pop('input_formats', None) - super(DateTimeField, self).__init__(**kwargs) - - -class FileField(Field): - pass # TODO + kwargs['read_only'] = True + super(ReadOnlyField, self).__init__(**kwargs) + def to_native(self, data): + raise NotImplemented('.to_native() not supported.') -class ReadOnlyField(Field): def to_primative(self, value): if is_simple_callable(value): return value() @@ -410,11 +755,28 @@ class ReadOnlyField(Field): class MethodField(Field): + """ + A read-only field that get its representation from calling a method on the + parent serializer class. The method called will be of the form + "get_{field_name}", and should take a single argument, which is the + object being serialized. + + For example: + + class ExampleSerializer(self): + extra_info = MethodField() + + def get_extra_info(self, obj): + return ... # Calculate some data to return. + """ def __init__(self, **kwargs): kwargs['source'] = '*' kwargs['read_only'] = True super(MethodField, self).__init__(**kwargs) + def to_native(self, data): + raise NotImplemented('.to_native() not supported.') + def to_primative(self, value): attr = 'get_{field_name}'.format(field_name=self.field_name) method = getattr(self.parent, attr) @@ -424,35 +786,14 @@ class MethodField(Field): class ModelField(Field): """ A generic field that can be used against an arbitrary model field. - """ - def __init__(self, *args, **kwargs): - try: - self.model_field = kwargs.pop('model_field') - except KeyError: - raise ValueError("ModelField requires 'model_field' kwarg") - - self.min_length = kwargs.pop('min_length', - getattr(self.model_field, 'min_length', None)) - self.max_length = kwargs.pop('max_length', - getattr(self.model_field, 'max_length', None)) - self.min_value = kwargs.pop('min_value', - getattr(self.model_field, 'min_value', None)) - self.max_value = kwargs.pop('max_value', - getattr(self.model_field, 'max_value', None)) - - super(ModelField, self).__init__(*args, **kwargs) - - if self.min_length is not None: - self.validators.append(validators.MinLengthValidator(self.min_length)) - if self.max_length is not None: - self.validators.append(validators.MaxLengthValidator(self.max_length)) - if self.min_value is not None: - self.validators.append(validators.MinValueValidator(self.min_value)) - if self.max_value is not None: - self.validators.append(validators.MaxValueValidator(self.max_value)) - def get_attribute(self, instance): - return get_attribute(instance, self.source_attrs[:-1]) + This is used by `ModelSerializer` when dealing with custom model fields, + that do not have a serializer field to be mapped to. + """ + def __init__(self, model_field, **kwargs): + self.model_field = model_field + kwargs['source'] = '*' + super(ModelField, self).__init__(**kwargs) def to_native(self, data): rel = getattr(self.model_field, 'rel', None) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 93226d32..8ca28387 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,15 +10,15 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ +from django.core import validators from django.core.exceptions import ValidationError from django.db import models from django.utils import six from collections import namedtuple, OrderedDict from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings -from rest_framework.utils import html +from rest_framework.utils import html, modelinfo, representation import copy -import inspect # Note: We do the following so that users of the framework can use this style: # @@ -146,12 +146,10 @@ class SerializerMetaclass(type): class Serializer(BaseSerializer): def __new__(cls, *args, **kwargs): - many = kwargs.pop('many', False) - if many: - class DynamicListSerializer(ListSerializer): - child = cls() - return DynamicListSerializer(*args, **kwargs) - return super(Serializer, cls).__new__(cls) + if kwargs.pop('many', False): + kwargs['child'] = cls() + return ListSerializer(*args, **kwargs) + return super(Serializer, cls).__new__(cls, *args, **kwargs) def __init__(self, *args, **kwargs): self.context = kwargs.pop('context', {}) @@ -248,6 +246,9 @@ class Serializer(BaseSerializer): error = errors.get(field.field_name) yield FieldResult(field, value, error) + def __repr__(self): + return representation.serializer_repr(self, indent=1) + class ListSerializer(BaseSerializer): child = None @@ -299,26 +300,8 @@ class ListSerializer(BaseSerializer): self.instance = self.create(self.validated_data) return self.instance - -def _resolve_model(obj): - """ - Resolve supplied `obj` to a Django model class. - - `obj` must be a Django model class itself, or a string - representation of one. Useful in situtations like GH #1225 where - Django may not have resolved a string-based reference to a model in - another model's foreign key definition. - - String representations should have the format: - 'appname.ModelName' - """ - if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: - app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) - elif inspect.isclass(obj) and issubclass(obj, models.Model): - return obj - else: - raise ValueError("{0} is not a Django model".format(obj)) + def __repr__(self): + return representation.list_repr(self, indent=1) class ModelSerializerOptions(object): @@ -334,24 +317,25 @@ class ModelSerializerOptions(object): class ModelSerializer(Serializer): field_mapping = { models.AutoField: IntegerField, - # models.FloatField: FloatField, + models.BigIntegerField: IntegerField, + models.BooleanField: BooleanField, + models.CharField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.DateField: DateField, + models.DateTimeField: DateTimeField, + models.DecimalField: DecimalField, + models.EmailField: EmailField, + models.FileField: FileField, + models.FloatField: FloatField, models.IntegerField: IntegerField, + models.NullBooleanField: BooleanField, models.PositiveIntegerField: IntegerField, - models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, + models.SlugField: SlugField, + models.SmallIntegerField: IntegerField, + models.TextField: CharField, models.TimeField: TimeField, - # models.DecimalField: DecimalField, - models.EmailField: EmailField, - models.CharField: CharField, models.URLField: URLField, - # models.SlugField: SlugField, - models.TextField: CharField, - models.CommaSeparatedIntegerField: CharField, - models.BooleanField: BooleanField, - models.NullBooleanField: BooleanField, - models.FileField: FileField, # models.ImageField: ImageField, } @@ -392,85 +376,31 @@ class ModelSerializer(Serializer): """ Return all the fields that should be serialized for the model. """ - cls = self.opts.model - opts = cls._meta.concrete_model._meta + info = modelinfo.get_field_info(self.opts.model) ret = OrderedDict() - nested = bool(self.opts.depth) - # Deal with adding the primary key field - pk_field = opts.pk - while pk_field.rel and pk_field.rel.parent_link: - # If model is a child via multitable inheritance, use parent's pk - pk_field = pk_field.rel.to._meta.pk - - serializer_pk_field = self.get_pk_field(pk_field) + serializer_pk_field = self.get_pk_field(info.pk) if serializer_pk_field: - ret[pk_field.name] = serializer_pk_field - - # Deal with forward relationships - forward_rels = [field for field in opts.fields if field.serialize] - forward_rels += [field for field in opts.many_to_many if field.serialize] - - for model_field in forward_rels: - has_through_model = False - - if model_field.rel: - to_many = isinstance(model_field, - models.fields.related.ManyToManyField) - related_model = _resolve_model(model_field.rel.to) - - if to_many and not model_field.rel.through._meta.auto_created: - has_through_model = True + ret[info.pk.name] = serializer_pk_field - if model_field.rel and nested: - field = self.get_nested_field(model_field, related_model, to_many) - elif model_field.rel: - field = self.get_related_field(model_field, related_model, to_many) - else: - field = self.get_field(model_field) - - if field: - if has_through_model: - field.read_only = True + # Regular fields + for field_name, field in info.fields.items(): + ret[field_name] = self.get_field(field) - ret[model_field.name] = field - - # Deal with reverse relationships - if not self.opts.fields: - reverse_rels = [] - else: - # Reverse relationships are only included if they are explicitly - # present in the `fields` option on the serializer - reverse_rels = opts.get_all_related_objects() - reverse_rels += opts.get_all_related_many_to_many_objects() - - for relation in reverse_rels: - accessor_name = relation.get_accessor_name() - if not self.opts.fields or accessor_name not in self.opts.fields: - continue - related_model = relation.model - to_many = relation.field.rel.multiple - has_through_model = False - is_m2m = isinstance(relation.field, - models.fields.related.ManyToManyField) - - if ( - is_m2m and - hasattr(relation.field.rel, 'through') and - not relation.field.rel.through._meta.auto_created - ): - has_through_model = True - - if nested: - field = self.get_nested_field(None, related_model, to_many) + # Forward relations + for field_name, relation_info in info.forward_relations.items(): + if self.opts.depth: + ret[field_name] = self.get_nested_field(*relation_info) else: - field = self.get_related_field(None, related_model, to_many) + ret[field_name] = self.get_related_field(*relation_info) - if field: - if has_through_model: - field.read_only = True - - ret[accessor_name] = field + # Reverse relations + for accessor_name, relation_info in info.reverse_relations.items(): + if accessor_name in self.opts.fields: + if self.opts.depth: + ret[field_name] = self.get_nested_field(*relation_info) + else: + ret[field_name] = self.get_related_field(*relation_info) return ret @@ -480,7 +410,7 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field, related_model, to_many): + def get_nested_field(self, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a nested relational field. @@ -491,59 +421,148 @@ class ModelSerializer(Serializer): model = related_model depth = self.opts.depth - 1 - return NestedModelSerializer(many=to_many) + kwargs = {'read_only': True} + if to_many: + kwargs['many'] = True + return NestedModelSerializer(**kwargs) - def get_related_field(self, model_field, related_model, to_many): + def get_related_field(self, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. Note that model_field will be `None` for reverse relationships. """ - # TODO: filter queryset using: - # .using(db).complex_filter(self.rel.limit_choices_to) + kwargs = { + 'queryset': related_model._default_manager, + } - kwargs = {} - # 'queryset': related_model._default_manager, - # 'many': to_many - # } + if to_many: + kwargs['many'] = True + + if has_through_model: + kwargs['read_only'] = True + kwargs.pop('queryset', None) if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) + if model_field.null or model_field.blank: + kwargs['required'] = False # if model_field.help_text is not None: # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name if not model_field.editable: kwargs['read_only'] = True - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name + kwargs.pop('queryset', None) - return IntegerField(**kwargs) - # TODO: return PrimaryKeyRelatedField(**kwargs) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): """ Creates a default instance of a basic non-relational field. """ kwargs = {} + validator_kwarg = model_field.validators if model_field.null or model_field.blank: kwargs['required'] = False + if model_field.verbose_name is not None: + kwargs['label'] = model_field.verbose_name + if isinstance(model_field, models.AutoField) or not model_field.editable: kwargs['read_only'] = True + # Read only implies that the field is not required. + # We have a cleaner repr on the instance if we don't set it. + kwargs.pop('required', None) if model_field.has_default(): kwargs['default'] = model_field.get_default() - - if issubclass(model_field.__class__, models.TextField): - kwargs['widget'] = widgets.Textarea - - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - - if model_field.validators is not None: - kwargs['validators'] = model_field.validators + # Having a default implies that the field is not required. + # We have a cleaner repr on the instance if we don't set it. + kwargs.pop('required', None) + + # Ensure that max_length is passed explicitly as a keyword arg, + # rather than as a validator. + max_length = getattr(model_field, 'max_length', None) + if max_length is not None: + kwargs['max_length'] = max_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxLengthValidator) + ] + + # Ensure that min_length is passed explicitly as a keyword arg, + # rather than as a validator. + min_length = getattr(model_field, 'min_length', None) + if min_length is not None: + kwargs['min_length'] = min_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinLengthValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + max_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MaxValueValidator) + ), None) + if max_value is not None: + kwargs['max_value'] = max_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxValueValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + min_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinValueValidator) + ), None) + if min_value is not None: + kwargs['min_value'] = min_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinValueValidator) + ] + + # URLField does not need to include the URLValidator argument, + # as it is explicitly added in. + if isinstance(model_field, models.URLField): + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.URLValidator) + ] + + # EmailField does not need to include the validate_email argument, + # as it is explicitly added in. + if isinstance(model_field, models.EmailField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_email + ] + + # SlugField do not need to include the 'validate_slug' argument, + if isinstance(model_field, models.SlugField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_slug + ] + + max_digits = getattr(model_field, 'max_digits', None) + if max_digits is not None: + kwargs['max_digits'] = max_digits + + decimal_places = getattr(model_field, 'decimal_places', None) + if decimal_places is not None: + kwargs['decimal_places'] = decimal_places + + if validator_kwarg: + kwargs['validators'] = validator_kwarg + + # if issubclass(model_field.__class__, models.TextField): + # kwargs['widget'] = widgets.Textarea # if model_field.help_text is not None: # kwargs['help_text'] = model_field.help_text @@ -555,31 +574,10 @@ class ModelSerializer(Serializer): kwargs['empty'] = None return ChoiceField(**kwargs) - # put this below the ChoiceField because min_value isn't a valid initializer - if issubclass(model_field.__class__, models.PositiveIntegerField) or \ - issubclass(model_field.__class__, models.PositiveSmallIntegerField): - kwargs['min_value'] = 0 - if model_field.null and \ issubclass(model_field.__class__, (models.CharField, models.TextField)): kwargs['allow_none'] = True - # attribute_dict = { - # models.CharField: ['max_length'], - # models.CommaSeparatedIntegerField: ['max_length'], - # models.DecimalField: ['max_digits', 'decimal_places'], - # models.EmailField: ['max_length'], - # models.FileField: ['max_length'], - # models.ImageField: ['max_length'], - # models.SlugField: ['max_length'], - # models.URLField: ['max_length'], - # } - - # if model_field.__class__ in attribute_dict: - # attributes = attribute_dict[model_field.__class__] - # for attribute in attributes: - # kwargs.update({attribute: getattr(model_field, attribute)}) - try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: @@ -594,28 +592,21 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): super(HyperlinkedModelSerializerOptions, self).__init__(meta) self.view_name = getattr(meta, 'view_name', None) self.lookup_field = getattr(meta, 'lookup_field', None) - self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME) class HyperlinkedModelSerializer(ModelSerializer): _options_class = HyperlinkedModelSerializerOptions - _default_view_name = '%(model_name)s-detail' - _hyperlink_field_class = HyperlinkedRelatedField - _hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() if self.opts.view_name is None: - self.opts.view_name = self._get_default_view_name(self.opts.model) + self.opts.view_name = self.get_default_view_name(self.opts.model) - if self.opts.url_field_name not in fields: - url_field = self._hyperlink_identify_field_class( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field - ) + url_field_name = api_settings.URL_FIELD_NAME + if url_field_name not in fields: ret = fields.__class__() - ret[self.opts.url_field_name] = url_field + ret[url_field_name] = self.get_url_field() ret.update(fields) fields = ret @@ -625,39 +616,48 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.fields and model_field.name in self.opts.fields: return self.get_field(model_field) - def get_related_field(self, model_field, related_model, to_many): + def get_url_field(self): + kwargs = { + 'view_name': self.get_default_view_name(self.opts.model) + } + if self.opts.lookup_field: + kwargs['lookup_field'] = self.opts.lookup_field + return HyperlinkedIdentityField(**kwargs) + + def get_related_field(self, model_field, related_model, to_many, has_through_model): """ Creates a default instance of a flat relational field. """ - # TODO: filter queryset using: - # .using(db).complex_filter(self.rel.limit_choices_to) - # kwargs = { - # 'queryset': related_model._default_manager, - # 'view_name': self._get_default_view_name(related_model), - # 'many': to_many - # } - kwargs = {} + kwargs = { + 'queryset': related_model._default_manager, + 'view_name': self.get_default_view_name(related_model), + } + + if to_many: + kwargs['many'] = True + + if has_through_model: + kwargs['read_only'] = True + kwargs.pop('queryset', None) if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) + if model_field.null or model_field.blank: + kwargs['required'] = False # if model_field.help_text is not None: # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name + if not model_field.editable: + kwargs['read_only'] = True + kwargs.pop('queryset', None) - return IntegerField(**kwargs) - # if self.opts.lookup_field: - # kwargs['lookup_field'] = self.opts.lookup_field - - # return self._hyperlink_field_class(**kwargs) + return HyperlinkedRelatedField(**kwargs) - def _get_default_view_name(self, model): + def get_default_view_name(self, model): """ - Return the view name to use if 'view_name' is not specified in 'Meta' + Return the view name to use for related models. """ - model_meta = model._meta - format_kwargs = { - 'app_label': model_meta.app_label, - 'model_name': model_meta.object_name.lower() + return '%(model_name)s-detail' % { + 'app_label': model._meta.app_label, + 'model_name': model._meta.object_name.lower() } - return self._default_view_name % format_kwargs diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py new file mode 100644 index 00000000..649f2abc --- /dev/null +++ b/rest_framework/utils/humanize_datetime.py @@ -0,0 +1,47 @@ +""" +Helper functions that convert strftime formats into more readable representations. +""" +from rest_framework import ISO_8601 + + +def datetime_formats(formats): + format = ', '.join(formats).replace( + ISO_8601, + 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' + ) + return humanize_strptime(format) + + +def date_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') + return humanize_strptime(format) + + +def time_formats(formats): + format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') + return humanize_strptime(format) + + +def humanize_strptime(format_string): + # Note that we're missing some of the locale specific mappings that + # don't really make sense. + mapping = { + "%Y": "YYYY", + "%y": "YY", + "%m": "MM", + "%b": "[Jan-Dec]", + "%B": "[January-December]", + "%d": "DD", + "%H": "hh", + "%I": "hh", # Requires '%p' to differentiate from '%H'. + "%M": "mm", + "%S": "ss", + "%f": "uuuuuu", + "%a": "[Mon-Sun]", + "%A": "[Monday-Sunday]", + "%p": "[AM|PM]", + "%z": "[+HHMM|-HHMM]" + } + for key, val in mapping.items(): + format_string = format_string.replace(key, val) + return format_string diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py new file mode 100644 index 00000000..c0513886 --- /dev/null +++ b/rest_framework/utils/modelinfo.py @@ -0,0 +1,97 @@ +""" +Helper functions for returning the field information that is associated +with a model class. +""" +from collections import namedtuple, OrderedDict +from django.db import models +from django.utils import six +import inspect + +FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) +RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) + + +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + + `obj` must be a Django model class itself, or a string + representation of one. Useful in situtations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. + + String representations should have the format: + 'appname.ModelName' + """ + if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + return models.get_model(app_name, model_name) + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + raise ValueError("{0} is not a Django model".format(obj)) + + +def get_field_info(model): + """ + Given a model class, returns a `FieldInfo` instance containing metadata + about the various field types on the model. + """ + opts = model._meta.concrete_model._meta + + # Deal with the primary key. + pk = opts.pk + while pk.rel and pk.rel.parent_link: + # If model is a child via multitable inheritance, use parent's pk. + pk = pk.rel.to._meta.pk + + # Deal with regular fields. + fields = OrderedDict() + for field in [field for field in opts.fields if field.serialize and not field.rel]: + fields[field.name] = field + + # Deal with forward relationships. + forward_relations = OrderedDict() + for field in [field for field in opts.fields if field.serialize and field.rel]: + forward_relations[field.name] = RelationInfo( + field=field, + related=_resolve_model(field.rel.to), + to_many=False, + has_through_model=False + ) + + # Deal with forward many-to-many relationships. + for field in [field for field in opts.many_to_many if field.serialize]: + forward_relations[field.name] = RelationInfo( + field=field, + related=_resolve_model(field.rel.to), + to_many=True, + has_through_model=( + not field.rel.through._meta.auto_created + ) + ) + + # Deal with reverse relationships. + reverse_relations = OrderedDict() + for relation in opts.get_all_related_objects(): + accessor_name = relation.get_accessor_name() + reverse_relations[accessor_name] = RelationInfo( + field=None, + related=relation.model, + to_many=relation.field.rel.multiple, + has_through_model=False + ) + + # Deal with reverse many-to-many relationships. + for relation in opts.get_all_related_many_to_many_objects(): + accessor_name = relation.get_accessor_name() + reverse_relations[accessor_name] = RelationInfo( + field=None, + related=relation.model, + to_many=True, + has_through_model=( + hasattr(relation.field.rel, 'through') and + not relation.field.rel.through._meta.auto_created + ) + ) + + return FieldInfo(pk, fields, forward_relations, reverse_relations) diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py new file mode 100644 index 00000000..1de21597 --- /dev/null +++ b/rest_framework/utils/representation.py @@ -0,0 +1,72 @@ +""" +Helper functions for creating user-friendly representations +of serializer classes and serializer fields. +""" +import re + + +def smart_repr(value): + value = repr(value) + + # Representations like u'help text' + # should simply be presented as 'help text' + if value.startswith("u'") and value.endswith("'"): + return value[1:] + + # Representations like + # <django.core.validators.RegexValidator object at 0x1047af050> + # Should be presented as + # <django.core.validators.RegexValidator object> + value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value) + + return value + + +def field_repr(field, force_many=False): + kwargs = field._kwargs + if force_many: + kwargs = kwargs.copy() + kwargs['many'] = True + kwargs.pop('child', None) + + arg_string = ', '.join([smart_repr(val) for val in field._args]) + kwarg_string = ', '.join([ + '%s=%s' % (key, smart_repr(val)) + for key, val in sorted(kwargs.items()) + ]) + if arg_string and kwarg_string: + arg_string += ', ' + + if force_many: + class_name = force_many.__class__.__name__ + else: + class_name = field.__class__.__name__ + + return "%s(%s%s)" % (class_name, arg_string, kwarg_string) + + +def serializer_repr(serializer, indent, force_many=None): + ret = field_repr(serializer, force_many) + ':' + indent_str = ' ' * indent + + if force_many: + fields = force_many.fields + else: + fields = serializer.fields + + for field_name, field in fields.items(): + ret += '\n' + indent_str + field_name + ' = ' + if hasattr(field, 'fields'): + ret += serializer_repr(field, indent + 1) + elif hasattr(field, 'child'): + ret += list_repr(field, indent + 1) + else: + ret += field_repr(field) + return ret + + +def list_repr(serializer, indent): + child = serializer.child + if hasattr(child, 'fields'): + return serializer_repr(serializer, indent, force_many=child) + return field_repr(serializer) |
