diff options
| author | Tom Christie | 2014-09-09 17:46:28 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-09-09 17:46:28 +0100 | 
| commit | b1c07670ca65084c5fef2bbb63d1f4163763014b (patch) | |
| tree | 4f08654d698990d97fe275d8dbbbcc1164524086 | |
| parent | 21980b800d04a1d82a6003823abfdf4ab80ae979 (diff) | |
| download | django-rest-framework-b1c07670ca65084c5fef2bbb63d1f4163763014b.tar.bz2 | |
Fleshing out serializer fields
| -rw-r--r-- | rest_framework/fields.py | 591 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 380 | ||||
| -rw-r--r-- | rest_framework/utils/humanize_datetime.py | 47 | ||||
| -rw-r--r-- | rest_framework/utils/modelinfo.py | 97 | ||||
| -rw-r--r-- | rest_framework/utils/representation.py | 72 | ||||
| -rw-r--r-- | tests/test_model_field_mappings.py | 160 | ||||
| -rw-r--r-- | tests/test_modelinfo.py (renamed from tests/test_serializers.py) | 2 | ||||
| -rw-r--r-- | tests/test_relations.py | 12 | ||||
| -rw-r--r-- | tests/test_serializer_empty.py | 2 | 
9 files changed, 1040 insertions, 323 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 250c0579..043a44ed 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,8 +1,18 @@ +from django.conf import settings  from django.core import validators  from django.core.exceptions import ValidationError +from django.utils import timezone +from django.utils.dateparse import parse_date, parse_datetime, parse_time  from django.utils.encoding import is_protected_type -from rest_framework.utils import html +from django.utils.translation import ugettext_lazy as _ +from rest_framework import ISO_8601 +from rest_framework.compat import smart_text +from rest_framework.settings import api_settings +from rest_framework.utils import html, representation, humanize_datetime +import datetime +import decimal  import inspect +import warnings  class empty: @@ -71,22 +81,22 @@ class SkipField(Exception):      pass +NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' +NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' +NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' +NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' +MISSING_ERROR_MESSAGE = ( +    'ValidationError raised by `{class_name}`, but error key `{key}` does ' +    'not exist in the `error_messages` dictionary.' +) + +  class Field(object):      _creation_counter = 0 -    MESSAGES = { -        'required': 'This field is required.' +    default_error_messages = { +        'required': _('This field is required.')      } - -    _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' -    _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' -    _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`' -    _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' -    _MISSING_ERROR_MESSAGE = ( -        'ValidationError raised by `{class_name}`, but error key `{key}` does ' -        'not exist in the `MESSAGES` dictionary.' -    ) -      default_validators = []      def __init__(self, read_only=False, write_only=False, @@ -100,10 +110,10 @@ class Field(object):              required = default is empty and not read_only          # Some combinations of keyword arguments do not make sense. -        assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY -        assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED -        assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT -        assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT +        assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY +        assert not (read_only and required), NOT_READ_ONLY_REQUIRED +        assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT +        assert not (required and default is not empty), NOT_REQUIRED_DEFAULT          self.read_only = read_only          self.write_only = write_only @@ -113,7 +123,14 @@ class Field(object):          self.initial = initial          self.label = label          self.style = {} if style is None else style -        self.validators = self.default_validators + validators +        self.validators = validators or self.default_validators[:] + +        # Collect default error message from self and parent classes +        messages = {} +        for cls in reversed(self.__class__.__mro__): +            messages.update(getattr(cls, 'default_error_messages', {})) +        messages.update(error_messages or {}) +        self.error_messages = messages      def bind(self, field_name, parent, root):          """ @@ -186,12 +203,14 @@ class Field(object):                  self.fail('required')              return self.get_default() -        self.run_validators(data) -        return self.to_native(data) +        value = self.to_native(data) +        self.run_validators(value) +        return value      def run_validators(self, value):          if value in validators.EMPTY_VALUES:              return +          errors = []          for validator in self.validators:              try: @@ -218,33 +237,32 @@ class Field(object):          A helper method that simply raises a validation error.          """          try: -            raise ValidationError(self.MESSAGES[key].format(**kwargs)) +            msg = self.error_messages[key]          except KeyError:              class_name = self.__class__.__name__ -            msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) +            msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)              raise AssertionError(msg) +        raise ValidationError(msg.format(**kwargs))      def __new__(cls, *args, **kwargs): +        """ +        When a field is instantiated, we store the arguments that were used, +        so that we can present a helpful representation of the object. +        """          instance = super(Field, cls).__new__(cls)          instance._args = args          instance._kwargs = kwargs          return instance      def __repr__(self): -        arg_string = ', '.join([repr(val) for val in self._args]) -        kwarg_string = ', '.join([ -            '%s=%s' % (key, repr(val)) for key, val in self._kwargs.items() -        ]) -        if arg_string and kwarg_string: -            arg_string += ', ' -        class_name = self.__class__.__name__ -        return "%s(%s%s)" % (class_name, arg_string, kwarg_string) +        return representation.field_repr(self) +# Boolean types... +  class BooleanField(Field): -    MESSAGES = { -        'required': 'This field is required.', -        'invalid_value': '`{input}` is not a valid boolean.' +    default_error_messages = { +        'invalid': _('`{input}` is not a valid boolean.')      }      TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}      FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False} @@ -261,13 +279,23 @@ class BooleanField(Field):              return True          elif data in self.FALSE_VALUES:              return False -        self.fail('invalid_value', input=data) +        self.fail('invalid', input=data) + +    def to_primative(self, value): +        if value is None: +            return None +        if value in self.TRUE_VALUES: +            return True +        elif value in self.FALSE_VALUES: +            return False +        return bool(value) +# String types... +  class CharField(Field): -    MESSAGES = { -        'required': 'This field is required.', -        'blank': 'This field may not be blank.' +    default_error_messages = { +        'blank': _('This field may not be blank.')      }      def __init__(self, **kwargs): @@ -281,19 +309,364 @@ class CharField(Field):              self.fail('blank')          return str(data) +    def to_primative(self, value): +        if value is None: +            return None +        return str(value) -class ChoiceField(Field): -    MESSAGES = { -        'required': 'This field is required.', -        'invalid_choice': '`{input}` is not a valid choice.' + +class EmailField(CharField): +    default_error_messages = { +        'invalid': _('Enter a valid email address.') +    } +    default_validators = [validators.validate_email] + +    def to_native(self, data): +        ret = super(EmailField, self).to_native(data) +        if ret is None: +            return None +        return ret.strip() + +    def to_primative(self, value): +        ret = super(EmailField, self).to_primative(value) +        if ret is None: +            return None +        return ret.strip() + + +class RegexField(CharField): +    def __init__(self, regex, **kwargs): +        kwargs['validators'] = ( +            [validators.RegexValidator(regex)] + +            kwargs.get('validators', []) +        ) +        super(RegexField, self).__init__(**kwargs) + + +class SlugField(CharField): +    default_error_messages = { +        'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.") +    } +    default_validators = [validators.validate_slug] + + +class URLField(CharField): +    default_error_messages = { +        'invalid': _("Enter a valid URL.") +    } +    default_validators = [validators.URLValidator()] + + +# Number types... + +class IntegerField(Field): +    default_error_messages = { +        'invalid': _('A valid integer is required.') +    } + +    def __init__(self, **kwargs): +        max_value = kwargs.pop('max_value', None) +        min_value = kwargs.pop('min_value', None) +        super(IntegerField, self).__init__(**kwargs) +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) +        print self.__class__.__name__, self.validators + +    def to_native(self, data): +        try: +            data = int(str(data)) +        except (ValueError, TypeError): +            self.fail('invalid') +        return data + +    def to_primative(self, value): +        if value is None: +            return None +        return int(value) + + +class FloatField(Field): +    default_error_messages = { +        'invalid': _("'%s' value must be a float."),      } -    coerce_to_type = str      def __init__(self, **kwargs): -        choices = kwargs.pop('choices') +        max_value = kwargs.pop('max_value', None) +        min_value = kwargs.pop('min_value', None) +        super(FloatField, self).__init__(**kwargs) +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) + +    def to_primative(self, value): +        if value is None: +            return None +        try: +            return float(value) +        except (TypeError, ValueError): +            self.fail('invalid', value=value) + +    def to_native(self, value): +        if value is None: +            return None +        return float(value) + + +class DecimalField(Field): +    default_error_messages = { +        'invalid': _('Enter a number.'), +        'max_value': _('Ensure this value is less than or equal to {max_value}.'), +        'min_value': _('Ensure this value is greater than or equal to {min_value}.'), +        'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), +        'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), +        'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.') +    } + +    def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs): +        self.max_value, self.min_value = max_value, min_value +        self.max_digits, self.max_decimal_places = max_digits, decimal_places +        super(DecimalField, self).__init__(**kwargs) +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) + +    def from_native(self, value): +        """ +        Validates that the input is a decimal number. Returns a Decimal +        instance. Returns None for empty values. Ensures that there are no more +        than max_digits in the number, and no more than decimal_places digits +        after the decimal point. +        """ +        if value in validators.EMPTY_VALUES: +            return None + +        value = smart_text(value).strip() +        try: +            value = decimal.Decimal(value) +        except decimal.DecimalException: +            self.fail('invalid') + +        # Check for NaN. It is the only value that isn't equal to itself, +        # so we can use this to identify NaN values. +        if value != value: +            self.fail('invalid') + +        # Check for infinity and negative infinity. +        if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): +            self.fail('invalid') + +        sign, digittuple, exponent = value.as_tuple() +        decimals = abs(exponent) +        # digittuple doesn't include any leading zeros. +        digits = len(digittuple) +        if decimals > digits: +            # We have leading zeros up to or past the decimal point.  Count +            # everything past the decimal point as a digit.  We do not count +            # 0 before the decimal point as a digit since that would mean +            # we would not allow max_digits = decimal_places. +            digits = decimals +        whole_digits = digits - decimals + +        if self.max_digits is not None and digits > self.max_digits: +            self.fail('max_digits', max_digits=self.max_digits) +        if self.decimal_places is not None and decimals > self.decimal_places: +            self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places) +        if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): +            self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places) + +        return value + + +# Date & time fields... + +class DateField(Field): +    default_error_messages = { +        'invalid': _("Date has wrong format. Use one of these formats instead: %s"), +    } +    input_formats = api_settings.DATE_INPUT_FORMATS +    format = api_settings.DATE_FORMAT + +    def __init__(self, input_formats=None, format=None, *args, **kwargs): +        self.input_formats = input_formats if input_formats is not None else self.input_formats +        self.format = format if format is not None else self.format +        super(DateField, self).__init__(*args, **kwargs) + +    def from_native(self, value): +        if value in validators.EMPTY_VALUES: +            return None + +        if isinstance(value, datetime.datetime): +            if timezone and settings.USE_TZ and timezone.is_aware(value): +                # Convert aware datetimes to the default time zone +                # before casting them to dates (#17742). +                default_timezone = timezone.get_default_timezone() +                value = timezone.make_naive(value, default_timezone) +            return value.date() +        if isinstance(value, datetime.date): +            return value + +        for format in self.input_formats: +            if format.lower() == ISO_8601: +                try: +                    parsed = parse_date(value) +                except (ValueError, TypeError): +                    pass +                else: +                    if parsed is not None: +                        return parsed +            else: +                try: +                    parsed = datetime.datetime.strptime(value, format) +                except (ValueError, TypeError): +                    pass +                else: +                    return parsed.date() + +        humanized_format = humanize_datetime.date_formats(self.input_formats) +        msg = self.error_messages['invalid'] % humanized_format +        raise ValidationError(msg) + +    def to_primative(self, value): +        if value is None or self.format is None: +            return value + +        if isinstance(value, datetime.datetime): +            value = value.date() + +        if self.format.lower() == ISO_8601: +            return value.isoformat() +        return value.strftime(self.format) + + +class DateTimeField(Field): +    default_error_messages = { +        'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"), +    } +    input_formats = api_settings.DATETIME_INPUT_FORMATS +    format = api_settings.DATETIME_FORMAT + +    def __init__(self, input_formats=None, format=None, *args, **kwargs): +        self.input_formats = input_formats if input_formats is not None else self.input_formats +        self.format = format if format is not None else self.format +        super(DateTimeField, self).__init__(*args, **kwargs) + +    def from_native(self, value): +        if value in validators.EMPTY_VALUES: +            return None + +        if isinstance(value, datetime.datetime): +            return value +        if isinstance(value, datetime.date): +            value = datetime.datetime(value.year, value.month, value.day) +            if settings.USE_TZ: +                # For backwards compatibility, interpret naive datetimes in +                # local time. This won't work during DST change, but we can't +                # do much about it, so we let the exceptions percolate up the +                # call stack. +                warnings.warn("DateTimeField received a naive datetime (%s)" +                              " while time zone support is active." % value, +                              RuntimeWarning) +                default_timezone = timezone.get_default_timezone() +                value = timezone.make_aware(value, default_timezone) +            return value + +        for format in self.input_formats: +            if format.lower() == ISO_8601: +                try: +                    parsed = parse_datetime(value) +                except (ValueError, TypeError): +                    pass +                else: +                    if parsed is not None: +                        return parsed +            else: +                try: +                    parsed = datetime.datetime.strptime(value, format) +                except (ValueError, TypeError): +                    pass +                else: +                    return parsed + +        humanized_format = humanize_datetime.datetime_formats(self.input_formats) +        msg = self.error_messages['invalid'] % humanized_format +        raise ValidationError(msg) + +    def to_primative(self, value): +        if value is None or self.format is None: +            return value + +        if self.format.lower() == ISO_8601: +            ret = value.isoformat() +            if ret.endswith('+00:00'): +                ret = ret[:-6] + 'Z' +            return ret +        return value.strftime(self.format) -        assert choices, '`choices` argument is required and may not be empty' +class TimeField(Field): +    default_error_messages = { +        'invalid': _("Time has wrong format. Use one of these formats instead: %s"), +    } +    input_formats = api_settings.TIME_INPUT_FORMATS +    format = api_settings.TIME_FORMAT + +    def __init__(self, input_formats=None, format=None, *args, **kwargs): +        self.input_formats = input_formats if input_formats is not None else self.input_formats +        self.format = format if format is not None else self.format +        super(TimeField, self).__init__(*args, **kwargs) + +    def from_native(self, value): +        if value in validators.EMPTY_VALUES: +            return None + +        if isinstance(value, datetime.time): +            return value + +        for format in self.input_formats: +            if format.lower() == ISO_8601: +                try: +                    parsed = parse_time(value) +                except (ValueError, TypeError): +                    pass +                else: +                    if parsed is not None: +                        return parsed +            else: +                try: +                    parsed = datetime.datetime.strptime(value, format) +                except (ValueError, TypeError): +                    pass +                else: +                    return parsed.time() + +        humanized_format = humanize_datetime.time_formats(self.input_formats) +        msg = self.error_messages['invalid'] % humanized_format +        raise ValidationError(msg) + +    def to_primative(self, value): +        if value is None or self.format is None: +            return value + +        if isinstance(value, datetime.datetime): +            value = value.time() + +        if self.format.lower() == ISO_8601: +            return value.isoformat() +        return value.strftime(self.format) + + +# Choice types... + +class ChoiceField(Field): +    default_error_messages = { +        'invalid_choice': _('`{input}` is not a valid choice.') +    } + +    def __init__(self, choices, **kwargs):          # Allow either single or paired choices style:          # choices = [1, 2, 3]          # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] @@ -321,12 +694,14 @@ class ChoiceField(Field):          except KeyError:              self.fail('invalid_choice', input=data) +    def to_primative(self, value): +        return value +  class MultipleChoiceField(ChoiceField): -    MESSAGES = { -        'required': 'This field is required.', -        'invalid_choice': '`{input}` is not a valid choice.', -        'not_a_list': 'Expected a list of items but got type `{input_type}`' +    default_error_messages = { +        'invalid_choice': _('`{input}` is not a valid choice.'), +        'not_a_list': _('Expected a list of items but got type `{input_type}`')      }      def to_native(self, data): @@ -337,72 +712,42 @@ class MultipleChoiceField(ChoiceField):              for item in data          ]) - -class IntegerField(Field): -    MESSAGES = { -        'required': 'This field is required.', -        'invalid_integer': 'A valid integer is required.' -    } - -    def __init__(self, **kwargs): -        max_value = kwargs.pop('max_value', None) -        min_value = kwargs.pop('min_value', None) -        super(IntegerField, self).__init__(**kwargs) -        if max_value is not None: -            self.validators.append(validators.MaxValueValidator(max_value)) -        if min_value is not None: -            self.validators.append(validators.MinValueValidator(min_value)) - -    def to_native(self, data): -        try: -            data = int(str(data)) -        except (ValueError, TypeError): -            self.fail('invalid_integer') -        return data -      def to_primative(self, value): -        if value is None: -            return None -        return int(value) +        return value -class EmailField(CharField): +# File types... + +class FileField(Field):      pass  # TODO -class URLField(CharField): +class ImageField(Field):      pass  # TODO -class RegexField(CharField): -    def __init__(self, **kwargs): -        self.regex = kwargs.pop('regex') -        super(CharField, self).__init__(**kwargs) - +# Advanced field types... -class DateField(CharField): -    def __init__(self, **kwargs): -        self.input_formats = kwargs.pop('input_formats', None) -        super(DateField, self).__init__(**kwargs) +class ReadOnlyField(Field): +    """ +    A read-only field that simply returns the field value. +    If the field is a method with no parameters, the method will be called +    and it's return value used as the representation. -class TimeField(CharField): -    def __init__(self, **kwargs): -        self.input_formats = kwargs.pop('input_formats', None) -        super(TimeField, self).__init__(**kwargs) +    For example, the following would call `get_expiry_date()` on the object: +    class ExampleSerializer(self): +        expiry_date = ReadOnlyField(source='get_expiry_date') +    """ -class DateTimeField(CharField):      def __init__(self, **kwargs): -        self.input_formats = kwargs.pop('input_formats', None) -        super(DateTimeField, self).__init__(**kwargs) - - -class FileField(Field): -    pass  # TODO +        kwargs['read_only'] = True +        super(ReadOnlyField, self).__init__(**kwargs) +    def to_native(self, data): +        raise NotImplemented('.to_native() not supported.') -class ReadOnlyField(Field):      def to_primative(self, value):          if is_simple_callable(value):              return value() @@ -410,11 +755,28 @@ class ReadOnlyField(Field):  class MethodField(Field): +    """ +    A read-only field that get its representation from calling a method on the +    parent serializer class. The method called will be of the form +    "get_{field_name}", and should take a single argument, which is the +    object being serialized. + +    For example: + +    class ExampleSerializer(self): +        extra_info = MethodField() + +        def get_extra_info(self, obj): +            return ...  # Calculate some data to return. +    """      def __init__(self, **kwargs):          kwargs['source'] = '*'          kwargs['read_only'] = True          super(MethodField, self).__init__(**kwargs) +    def to_native(self, data): +        raise NotImplemented('.to_native() not supported.') +      def to_primative(self, value):          attr = 'get_{field_name}'.format(field_name=self.field_name)          method = getattr(self.parent, attr) @@ -424,35 +786,14 @@ class MethodField(Field):  class ModelField(Field):      """      A generic field that can be used against an arbitrary model field. -    """ -    def __init__(self, *args, **kwargs): -        try: -            self.model_field = kwargs.pop('model_field') -        except KeyError: -            raise ValueError("ModelField requires 'model_field' kwarg") - -        self.min_length = kwargs.pop('min_length', -                                     getattr(self.model_field, 'min_length', None)) -        self.max_length = kwargs.pop('max_length', -                                     getattr(self.model_field, 'max_length', None)) -        self.min_value = kwargs.pop('min_value', -                                    getattr(self.model_field, 'min_value', None)) -        self.max_value = kwargs.pop('max_value', -                                    getattr(self.model_field, 'max_value', None)) - -        super(ModelField, self).__init__(*args, **kwargs) - -        if self.min_length is not None: -            self.validators.append(validators.MinLengthValidator(self.min_length)) -        if self.max_length is not None: -            self.validators.append(validators.MaxLengthValidator(self.max_length)) -        if self.min_value is not None: -            self.validators.append(validators.MinValueValidator(self.min_value)) -        if self.max_value is not None: -            self.validators.append(validators.MaxValueValidator(self.max_value)) -    def get_attribute(self, instance): -        return get_attribute(instance, self.source_attrs[:-1]) +    This is used by `ModelSerializer` when dealing with custom model fields, +    that do not have a serializer field to be mapped to. +    """ +    def __init__(self, model_field, **kwargs): +        self.model_field = model_field +        kwargs['source'] = '*' +        super(ModelField, self).__init__(**kwargs)      def to_native(self, data):          rel = getattr(self.model_field, 'rel', None) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 93226d32..8ca28387 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,15 +10,15 @@ python primitives.  2. The process of marshalling between python primitives and request and  response content is handled by parsers and renderers.  """ +from django.core import validators  from django.core.exceptions import ValidationError  from django.db import models  from django.utils import six  from collections import namedtuple, OrderedDict  from rest_framework.fields import empty, set_value, Field, SkipField  from rest_framework.settings import api_settings -from rest_framework.utils import html +from rest_framework.utils import html, modelinfo, representation  import copy -import inspect  # Note: We do the following so that users of the framework can use this style:  # @@ -146,12 +146,10 @@ class SerializerMetaclass(type):  class Serializer(BaseSerializer):      def __new__(cls, *args, **kwargs): -        many = kwargs.pop('many', False) -        if many: -            class DynamicListSerializer(ListSerializer): -                child = cls() -            return DynamicListSerializer(*args, **kwargs) -        return super(Serializer, cls).__new__(cls) +        if kwargs.pop('many', False): +            kwargs['child'] = cls() +            return ListSerializer(*args, **kwargs) +        return super(Serializer, cls).__new__(cls, *args, **kwargs)      def __init__(self, *args, **kwargs):          self.context = kwargs.pop('context', {}) @@ -248,6 +246,9 @@ class Serializer(BaseSerializer):              error = errors.get(field.field_name)              yield FieldResult(field, value, error) +    def __repr__(self): +        return representation.serializer_repr(self, indent=1) +  class ListSerializer(BaseSerializer):      child = None @@ -299,26 +300,8 @@ class ListSerializer(BaseSerializer):          self.instance = self.create(self.validated_data)          return self.instance - -def _resolve_model(obj): -    """ -    Resolve supplied `obj` to a Django model class. - -    `obj` must be a Django model class itself, or a string -    representation of one.  Useful in situtations like GH #1225 where -    Django may not have resolved a string-based reference to a model in -    another model's foreign key definition. - -    String representations should have the format: -        'appname.ModelName' -    """ -    if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: -        app_name, model_name = obj.split('.') -        return models.get_model(app_name, model_name) -    elif inspect.isclass(obj) and issubclass(obj, models.Model): -        return obj -    else: -        raise ValueError("{0} is not a Django model".format(obj)) +    def __repr__(self): +        return representation.list_repr(self, indent=1)  class ModelSerializerOptions(object): @@ -334,24 +317,25 @@ class ModelSerializerOptions(object):  class ModelSerializer(Serializer):      field_mapping = {          models.AutoField: IntegerField, -        # models.FloatField: FloatField, +        models.BigIntegerField: IntegerField, +        models.BooleanField: BooleanField, +        models.CharField: CharField, +        models.CommaSeparatedIntegerField: CharField, +        models.DateField: DateField, +        models.DateTimeField: DateTimeField, +        models.DecimalField: DecimalField, +        models.EmailField: EmailField, +        models.FileField: FileField, +        models.FloatField: FloatField,          models.IntegerField: IntegerField, +        models.NullBooleanField: BooleanField,          models.PositiveIntegerField: IntegerField, -        models.SmallIntegerField: IntegerField,          models.PositiveSmallIntegerField: IntegerField, -        models.DateTimeField: DateTimeField, -        models.DateField: DateField, +        models.SlugField: SlugField, +        models.SmallIntegerField: IntegerField, +        models.TextField: CharField,          models.TimeField: TimeField, -        # models.DecimalField: DecimalField, -        models.EmailField: EmailField, -        models.CharField: CharField,          models.URLField: URLField, -        # models.SlugField: SlugField, -        models.TextField: CharField, -        models.CommaSeparatedIntegerField: CharField, -        models.BooleanField: BooleanField, -        models.NullBooleanField: BooleanField, -        models.FileField: FileField,          # models.ImageField: ImageField,      } @@ -392,85 +376,31 @@ class ModelSerializer(Serializer):          """          Return all the fields that should be serialized for the model.          """ -        cls = self.opts.model -        opts = cls._meta.concrete_model._meta +        info = modelinfo.get_field_info(self.opts.model)          ret = OrderedDict() -        nested = bool(self.opts.depth) -        # Deal with adding the primary key field -        pk_field = opts.pk -        while pk_field.rel and pk_field.rel.parent_link: -            # If model is a child via multitable inheritance, use parent's pk -            pk_field = pk_field.rel.to._meta.pk - -        serializer_pk_field = self.get_pk_field(pk_field) +        serializer_pk_field = self.get_pk_field(info.pk)          if serializer_pk_field: -            ret[pk_field.name] = serializer_pk_field - -        # Deal with forward relationships -        forward_rels = [field for field in opts.fields if field.serialize] -        forward_rels += [field for field in opts.many_to_many if field.serialize] - -        for model_field in forward_rels: -            has_through_model = False - -            if model_field.rel: -                to_many = isinstance(model_field, -                                     models.fields.related.ManyToManyField) -                related_model = _resolve_model(model_field.rel.to) - -                if to_many and not model_field.rel.through._meta.auto_created: -                    has_through_model = True +            ret[info.pk.name] = serializer_pk_field -            if model_field.rel and nested: -                field = self.get_nested_field(model_field, related_model, to_many) -            elif model_field.rel: -                field = self.get_related_field(model_field, related_model, to_many) -            else: -                field = self.get_field(model_field) - -            if field: -                if has_through_model: -                    field.read_only = True +        # Regular fields +        for field_name, field in info.fields.items(): +            ret[field_name] = self.get_field(field) -                ret[model_field.name] = field - -        # Deal with reverse relationships -        if not self.opts.fields: -            reverse_rels = [] -        else: -            # Reverse relationships are only included if they are explicitly -            # present in the `fields` option on the serializer -            reverse_rels = opts.get_all_related_objects() -            reverse_rels += opts.get_all_related_many_to_many_objects() - -        for relation in reverse_rels: -            accessor_name = relation.get_accessor_name() -            if not self.opts.fields or accessor_name not in self.opts.fields: -                continue -            related_model = relation.model -            to_many = relation.field.rel.multiple -            has_through_model = False -            is_m2m = isinstance(relation.field, -                                models.fields.related.ManyToManyField) - -            if ( -                is_m2m and -                hasattr(relation.field.rel, 'through') and -                not relation.field.rel.through._meta.auto_created -            ): -                has_through_model = True - -            if nested: -                field = self.get_nested_field(None, related_model, to_many) +        # Forward relations +        for field_name, relation_info in info.forward_relations.items(): +            if self.opts.depth: +                ret[field_name] = self.get_nested_field(*relation_info)              else: -                field = self.get_related_field(None, related_model, to_many) +                ret[field_name] = self.get_related_field(*relation_info) -            if field: -                if has_through_model: -                    field.read_only = True - -                ret[accessor_name] = field +        # Reverse relations +        for accessor_name, relation_info in info.reverse_relations.items(): +            if accessor_name in self.opts.fields: +                if self.opts.depth: +                    ret[field_name] = self.get_nested_field(*relation_info) +                else: +                    ret[field_name] = self.get_related_field(*relation_info)          return ret @@ -480,7 +410,7 @@ class ModelSerializer(Serializer):          """          return self.get_field(model_field) -    def get_nested_field(self, model_field, related_model, to_many): +    def get_nested_field(self, model_field, related_model, to_many, has_through_model):          """          Creates a default instance of a nested relational field. @@ -491,59 +421,148 @@ class ModelSerializer(Serializer):                  model = related_model                  depth = self.opts.depth - 1 -        return NestedModelSerializer(many=to_many) +        kwargs = {'read_only': True} +        if to_many: +            kwargs['many'] = True +        return NestedModelSerializer(**kwargs) -    def get_related_field(self, model_field, related_model, to_many): +    def get_related_field(self, model_field, related_model, to_many, has_through_model):          """          Creates a default instance of a flat relational field.          Note that model_field will be `None` for reverse relationships.          """ -        # TODO: filter queryset using: -        # .using(db).complex_filter(self.rel.limit_choices_to) +        kwargs = { +            'queryset': related_model._default_manager, +        } -        kwargs = {} -        #     'queryset': related_model._default_manager, -        #     'many': to_many -        # } +        if to_many: +            kwargs['many'] = True + +        if has_through_model: +            kwargs['read_only'] = True +            kwargs.pop('queryset', None)          if model_field: -            kwargs['required'] = not(model_field.null or model_field.blank) +            if model_field.null or model_field.blank: +                kwargs['required'] = False              #  if model_field.help_text is not None:              #      kwargs['help_text'] = model_field.help_text              if model_field.verbose_name is not None:                  kwargs['label'] = model_field.verbose_name              if not model_field.editable:                  kwargs['read_only'] = True -            if model_field.verbose_name is not None: -                kwargs['label'] = model_field.verbose_name +                kwargs.pop('queryset', None) -        return IntegerField(**kwargs) -        # TODO: return PrimaryKeyRelatedField(**kwargs) +        return PrimaryKeyRelatedField(**kwargs)      def get_field(self, model_field):          """          Creates a default instance of a basic non-relational field.          """          kwargs = {} +        validator_kwarg = model_field.validators          if model_field.null or model_field.blank:              kwargs['required'] = False +        if model_field.verbose_name is not None: +            kwargs['label'] = model_field.verbose_name +          if isinstance(model_field, models.AutoField) or not model_field.editable:              kwargs['read_only'] = True +            # Read only implies that the field is not required. +            # We have a cleaner repr on the instance if we don't set it. +            kwargs.pop('required', None)          if model_field.has_default():              kwargs['default'] = model_field.get_default() - -        if issubclass(model_field.__class__, models.TextField): -            kwargs['widget'] = widgets.Textarea - -        if model_field.verbose_name is not None: -            kwargs['label'] = model_field.verbose_name - -        if model_field.validators is not None: -            kwargs['validators'] = model_field.validators +            # Having a default implies that the field is not required. +            # We have a cleaner repr on the instance if we don't set it. +            kwargs.pop('required', None) + +        # Ensure that max_length is passed explicitly as a keyword arg, +        # rather than as a validator. +        max_length = getattr(model_field, 'max_length', None) +        if max_length is not None: +            kwargs['max_length'] = max_length +            validator_kwarg = [ +                validator for validator in validator_kwarg +                if not isinstance(validator, validators.MaxLengthValidator) +            ] + +        # Ensure that min_length is passed explicitly as a keyword arg, +        # rather than as a validator. +        min_length = getattr(model_field, 'min_length', None) +        if min_length is not None: +            kwargs['min_length'] = min_length +            validator_kwarg = [ +                validator for validator in validator_kwarg +                if not isinstance(validator, validators.MinLengthValidator) +            ] + +        # Ensure that max_value is passed explicitly as a keyword arg, +        # rather than as a validator. +        max_value = next(( +            validator.limit_value for validator in validator_kwarg +            if isinstance(validator, validators.MaxValueValidator) +        ), None) +        if max_value is not None: +            kwargs['max_value'] = max_value +            validator_kwarg = [ +                validator for validator in validator_kwarg +                if not isinstance(validator, validators.MaxValueValidator) +            ] + +        # Ensure that max_value is passed explicitly as a keyword arg, +        # rather than as a validator. +        min_value = next(( +            validator.limit_value for validator in validator_kwarg +            if isinstance(validator, validators.MinValueValidator) +        ), None) +        if min_value is not None: +            kwargs['min_value'] = min_value +            validator_kwarg = [ +                validator for validator in validator_kwarg +                if not isinstance(validator, validators.MinValueValidator) +            ] + +        # URLField does not need to include the URLValidator argument, +        # as it is explicitly added in. +        if isinstance(model_field, models.URLField): +            validator_kwarg = [ +                validator for validator in validator_kwarg +                if not isinstance(validator, validators.URLValidator) +            ] + +        # EmailField does not need to include the validate_email argument, +        # as it is explicitly added in. +        if isinstance(model_field, models.EmailField): +            validator_kwarg = [ +                validator for validator in validator_kwarg +                if validator is not validators.validate_email +            ] + +        # SlugField do not need to include the 'validate_slug' argument, +        if isinstance(model_field, models.SlugField): +            validator_kwarg = [ +                validator for validator in validator_kwarg +                if validator is not validators.validate_slug +            ] + +        max_digits = getattr(model_field, 'max_digits', None) +        if max_digits is not None: +            kwargs['max_digits'] = max_digits + +        decimal_places = getattr(model_field, 'decimal_places', None) +        if decimal_places is not None: +            kwargs['decimal_places'] = decimal_places + +        if validator_kwarg: +            kwargs['validators'] = validator_kwarg + +        # if issubclass(model_field.__class__, models.TextField): +        #     kwargs['widget'] = widgets.Textarea          # if model_field.help_text is not None:          #     kwargs['help_text'] = model_field.help_text @@ -555,31 +574,10 @@ class ModelSerializer(Serializer):                  kwargs['empty'] = None              return ChoiceField(**kwargs) -        # put this below the ChoiceField because min_value isn't a valid initializer -        if issubclass(model_field.__class__, models.PositiveIntegerField) or \ -                issubclass(model_field.__class__, models.PositiveSmallIntegerField): -            kwargs['min_value'] = 0 -          if model_field.null and \                  issubclass(model_field.__class__, (models.CharField, models.TextField)):              kwargs['allow_none'] = True -        # attribute_dict = { -        #     models.CharField: ['max_length'], -        #     models.CommaSeparatedIntegerField: ['max_length'], -        #     models.DecimalField: ['max_digits', 'decimal_places'], -        #     models.EmailField: ['max_length'], -        #     models.FileField: ['max_length'], -        #     models.ImageField: ['max_length'], -        #     models.SlugField: ['max_length'], -        #     models.URLField: ['max_length'], -        # } - -        # if model_field.__class__ in attribute_dict: -        #     attributes = attribute_dict[model_field.__class__] -        #     for attribute in attributes: -        #         kwargs.update({attribute: getattr(model_field, attribute)}) -          try:              return self.field_mapping[model_field.__class__](**kwargs)          except KeyError: @@ -594,28 +592,21 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):          super(HyperlinkedModelSerializerOptions, self).__init__(meta)          self.view_name = getattr(meta, 'view_name', None)          self.lookup_field = getattr(meta, 'lookup_field', None) -        self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)  class HyperlinkedModelSerializer(ModelSerializer):      _options_class = HyperlinkedModelSerializerOptions -    _default_view_name = '%(model_name)s-detail' -    _hyperlink_field_class = HyperlinkedRelatedField -    _hyperlink_identify_field_class = HyperlinkedIdentityField      def get_default_fields(self):          fields = super(HyperlinkedModelSerializer, self).get_default_fields()          if self.opts.view_name is None: -            self.opts.view_name = self._get_default_view_name(self.opts.model) +            self.opts.view_name = self.get_default_view_name(self.opts.model) -        if self.opts.url_field_name not in fields: -            url_field = self._hyperlink_identify_field_class( -                view_name=self.opts.view_name, -                lookup_field=self.opts.lookup_field -            ) +        url_field_name = api_settings.URL_FIELD_NAME +        if url_field_name not in fields:              ret = fields.__class__() -            ret[self.opts.url_field_name] = url_field +            ret[url_field_name] = self.get_url_field()              ret.update(fields)              fields = ret @@ -625,39 +616,48 @@ class HyperlinkedModelSerializer(ModelSerializer):          if self.opts.fields and model_field.name in self.opts.fields:              return self.get_field(model_field) -    def get_related_field(self, model_field, related_model, to_many): +    def get_url_field(self): +        kwargs = { +            'view_name': self.get_default_view_name(self.opts.model) +        } +        if self.opts.lookup_field: +            kwargs['lookup_field'] = self.opts.lookup_field +        return HyperlinkedIdentityField(**kwargs) + +    def get_related_field(self, model_field, related_model, to_many, has_through_model):          """          Creates a default instance of a flat relational field.          """ -        # TODO: filter queryset using: -        # .using(db).complex_filter(self.rel.limit_choices_to) -        # kwargs = { -        #     'queryset': related_model._default_manager, -        #     'view_name': self._get_default_view_name(related_model), -        #     'many': to_many -        # } -        kwargs = {} +        kwargs = { +            'queryset': related_model._default_manager, +            'view_name': self.get_default_view_name(related_model), +        } + +        if to_many: +            kwargs['many'] = True + +        if has_through_model: +            kwargs['read_only'] = True +            kwargs.pop('queryset', None)          if model_field: -            kwargs['required'] = not(model_field.null or model_field.blank) +            if model_field.null or model_field.blank: +                kwargs['required'] = False              # if model_field.help_text is not None:              #     kwargs['help_text'] = model_field.help_text              if model_field.verbose_name is not None:                  kwargs['label'] = model_field.verbose_name +            if not model_field.editable: +                kwargs['read_only'] = True +                kwargs.pop('queryset', None) -        return IntegerField(**kwargs) -        # if self.opts.lookup_field: -        #     kwargs['lookup_field'] = self.opts.lookup_field - -        # return self._hyperlink_field_class(**kwargs) +        return HyperlinkedRelatedField(**kwargs) -    def _get_default_view_name(self, model): +    def get_default_view_name(self, model):          """ -        Return the view name to use if 'view_name' is not specified in 'Meta' +        Return the view name to use for related models.          """ -        model_meta = model._meta -        format_kwargs = { -            'app_label': model_meta.app_label, -            'model_name': model_meta.object_name.lower() +        return '%(model_name)s-detail' % { +            'app_label': model._meta.app_label, +            'model_name': model._meta.object_name.lower()          } -        return self._default_view_name % format_kwargs diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py new file mode 100644 index 00000000..649f2abc --- /dev/null +++ b/rest_framework/utils/humanize_datetime.py @@ -0,0 +1,47 @@ +""" +Helper functions that convert strftime formats into more readable representations. +""" +from rest_framework import ISO_8601 + + +def datetime_formats(formats): +    format = ', '.join(formats).replace( +        ISO_8601, +        'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' +    ) +    return humanize_strptime(format) + + +def date_formats(formats): +    format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]') +    return humanize_strptime(format) + + +def time_formats(formats): +    format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') +    return humanize_strptime(format) + + +def humanize_strptime(format_string): +    # Note that we're missing some of the locale specific mappings that +    # don't really make sense. +    mapping = { +        "%Y": "YYYY", +        "%y": "YY", +        "%m": "MM", +        "%b": "[Jan-Dec]", +        "%B": "[January-December]", +        "%d": "DD", +        "%H": "hh", +        "%I": "hh",  # Requires '%p' to differentiate from '%H'. +        "%M": "mm", +        "%S": "ss", +        "%f": "uuuuuu", +        "%a": "[Mon-Sun]", +        "%A": "[Monday-Sunday]", +        "%p": "[AM|PM]", +        "%z": "[+HHMM|-HHMM]" +    } +    for key, val in mapping.items(): +        format_string = format_string.replace(key, val) +    return format_string diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py new file mode 100644 index 00000000..c0513886 --- /dev/null +++ b/rest_framework/utils/modelinfo.py @@ -0,0 +1,97 @@ +""" +Helper functions for returning the field information that is associated +with a model class. +""" +from collections import namedtuple, OrderedDict +from django.db import models +from django.utils import six +import inspect + +FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) +RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) + + +def _resolve_model(obj): +    """ +    Resolve supplied `obj` to a Django model class. + +    `obj` must be a Django model class itself, or a string +    representation of one.  Useful in situtations like GH #1225 where +    Django may not have resolved a string-based reference to a model in +    another model's foreign key definition. + +    String representations should have the format: +        'appname.ModelName' +    """ +    if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: +        app_name, model_name = obj.split('.') +        return models.get_model(app_name, model_name) +    elif inspect.isclass(obj) and issubclass(obj, models.Model): +        return obj +    raise ValueError("{0} is not a Django model".format(obj)) + + +def get_field_info(model): +    """ +    Given a model class, returns a `FieldInfo` instance containing metadata +    about the various field types on the model. +    """ +    opts = model._meta.concrete_model._meta + +    # Deal with the primary key. +    pk = opts.pk +    while pk.rel and pk.rel.parent_link: +        # If model is a child via multitable inheritance, use parent's pk. +        pk = pk.rel.to._meta.pk + +    # Deal with regular fields. +    fields = OrderedDict() +    for field in [field for field in opts.fields if field.serialize and not field.rel]: +        fields[field.name] = field + +    # Deal with forward relationships. +    forward_relations = OrderedDict() +    for field in [field for field in opts.fields if field.serialize and field.rel]: +        forward_relations[field.name] = RelationInfo( +            field=field, +            related=_resolve_model(field.rel.to), +            to_many=False, +            has_through_model=False +        ) + +    # Deal with forward many-to-many relationships. +    for field in [field for field in opts.many_to_many if field.serialize]: +        forward_relations[field.name] = RelationInfo( +            field=field, +            related=_resolve_model(field.rel.to), +            to_many=True, +            has_through_model=( +                not field.rel.through._meta.auto_created +            ) +        ) + +    # Deal with reverse relationships. +    reverse_relations = OrderedDict() +    for relation in opts.get_all_related_objects(): +        accessor_name = relation.get_accessor_name() +        reverse_relations[accessor_name] = RelationInfo( +            field=None, +            related=relation.model, +            to_many=relation.field.rel.multiple, +            has_through_model=False +        ) + +    # Deal with reverse many-to-many relationships. +    for relation in opts.get_all_related_many_to_many_objects(): +        accessor_name = relation.get_accessor_name() +        reverse_relations[accessor_name] = RelationInfo( +            field=None, +            related=relation.model, +            to_many=True, +            has_through_model=( +                hasattr(relation.field.rel, 'through') and +                not relation.field.rel.through._meta.auto_created +            ) +        ) + +    return FieldInfo(pk, fields, forward_relations, reverse_relations) diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py new file mode 100644 index 00000000..1de21597 --- /dev/null +++ b/rest_framework/utils/representation.py @@ -0,0 +1,72 @@ +""" +Helper functions for creating user-friendly representations +of serializer classes and serializer fields. +""" +import re + + +def smart_repr(value): +    value = repr(value) + +    # Representations like u'help text' +    # should simply be presented as 'help text' +    if value.startswith("u'") and value.endswith("'"): +        return value[1:] + +    # Representations like +    # <django.core.validators.RegexValidator object at 0x1047af050> +    # Should be presented as +    # <django.core.validators.RegexValidator object> +    value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value) + +    return value + + +def field_repr(field, force_many=False): +    kwargs = field._kwargs +    if force_many: +        kwargs = kwargs.copy() +        kwargs['many'] = True +        kwargs.pop('child', None) + +    arg_string = ', '.join([smart_repr(val) for val in field._args]) +    kwarg_string = ', '.join([ +        '%s=%s' % (key, smart_repr(val)) +        for key, val in sorted(kwargs.items()) +    ]) +    if arg_string and kwarg_string: +        arg_string += ', ' + +    if force_many: +        class_name = force_many.__class__.__name__ +    else: +        class_name = field.__class__.__name__ + +    return "%s(%s%s)" % (class_name, arg_string, kwarg_string) + + +def serializer_repr(serializer, indent, force_many=None): +    ret = field_repr(serializer, force_many) + ':' +    indent_str = '    ' * indent + +    if force_many: +        fields = force_many.fields +    else: +        fields = serializer.fields + +    for field_name, field in fields.items(): +        ret += '\n' + indent_str + field_name + ' = ' +        if hasattr(field, 'fields'): +            ret += serializer_repr(field, indent + 1) +        elif hasattr(field, 'child'): +            ret += list_repr(field, indent + 1) +        else: +            ret += field_repr(field) +    return ret + + +def list_repr(serializer, indent): +    child = serializer.child +    if hasattr(child, 'fields'): +        return serializer_repr(serializer, indent, force_many=child) +    return field_repr(serializer) diff --git a/tests/test_model_field_mappings.py b/tests/test_model_field_mappings.py new file mode 100644 index 00000000..dc254da4 --- /dev/null +++ b/tests/test_model_field_mappings.py @@ -0,0 +1,160 @@ +""" +The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially +shortcuts for automatically creating serializers based on a given model class. + +These tests deal with ensuring that we correctly map the model fields onto +an appropriate set of serializer fields for each case. +""" +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +# Models for testing regular field mapping + +class RegularFieldsModel(models.Model): +    auto_field = models.AutoField(primary_key=True) +    big_integer_field = models.BigIntegerField() +    boolean_field = models.BooleanField() +    char_field = models.CharField(max_length=100) +    comma_seperated_integer_field = models.CommaSeparatedIntegerField(max_length=100) +    date_field = models.DateField() +    datetime_field = models.DateTimeField() +    decimal_field = models.DecimalField(max_digits=3, decimal_places=1) +    email_field = models.EmailField(max_length=100) +    float_field = models.FloatField() +    integer_field = models.IntegerField() +    null_boolean_field = models.NullBooleanField() +    positive_integer_field = models.PositiveIntegerField() +    positive_small_integer_field = models.PositiveSmallIntegerField() +    slug_field = models.SlugField(max_length=100) +    small_integer_field = models.SmallIntegerField() +    text_field = models.TextField() +    time_field = models.TimeField() +    url_field = models.URLField(max_length=100) + + +REGULAR_FIELDS_REPR = """ +TestSerializer(): +    auto_field = IntegerField(label='auto field', read_only=True) +    big_integer_field = IntegerField(label='big integer field') +    boolean_field = BooleanField(default=False, label='boolean field') +    char_field = CharField(label='char field', max_length=100) +    comma_seperated_integer_field = CharField(label='comma seperated integer field', max_length=100, validators=[<django.core.validators.RegexValidator object>]) +    date_field = DateField(label='date field') +    datetime_field = DateTimeField(label='datetime field') +    decimal_field = DecimalField(decimal_places=1, label='decimal field', max_digits=3) +    email_field = EmailField(label='email field', max_length=100) +    float_field = FloatField(label='float field') +    integer_field = IntegerField(label='integer field') +    null_boolean_field = BooleanField(label='null boolean field', required=False) +    positive_integer_field = IntegerField(label='positive integer field') +    positive_small_integer_field = IntegerField(label='positive small integer field') +    slug_field = SlugField(label='slug field', max_length=100) +    small_integer_field = IntegerField(label='small integer field') +    text_field = CharField(label='text field') +    time_field = TimeField(label='time field') +    url_field = URLField(label='url field', max_length=100) +""".strip() + + +# Model for testing relational field mapping + +class ForeignKeyTarget(models.Model): +    char_field = models.CharField(max_length=100) + + +class ManyToManyTarget(models.Model): +    char_field = models.CharField(max_length=100) + + +class OneToOneTarget(models.Model): +    char_field = models.CharField(max_length=100) + + +class RelationalModel(models.Model): +    foreign_key = models.ForeignKey(ForeignKeyTarget) +    many_to_many = models.ManyToManyField(ManyToManyTarget) +    one_to_one = models.OneToOneField(OneToOneTarget) + + +RELATIONAL_FLAT_REPR = """ +TestSerializer(): +    id = IntegerField(label='ID', read_only=True) +    foreign_key = PrimaryKeyRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>) +    one_to_one = PrimaryKeyRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>) +    many_to_many = PrimaryKeyRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>) +""".strip() + + +RELATIONAL_NESTED_REPR = """ +TestSerializer(): +    id = IntegerField(label='ID', read_only=True) +    foreign_key = NestedModelSerializer(read_only=True): +        id = IntegerField(label='ID', read_only=True) +        name = CharField(label='name', max_length=100) +    one_to_one = NestedModelSerializer(read_only=True): +        id = IntegerField(label='ID', read_only=True) +        name = CharField(label='name', max_length=100) +    many_to_many = NestedModelSerializer(many=True, read_only=True): +        id = IntegerField(label='ID', read_only=True) +        name = CharField(label='name', max_length=100) +""".strip() + + +HYPERLINKED_FLAT_REPR = """ +TestSerializer(): +    url = HyperlinkedIdentityField(view_name='relationalmodel-detail') +    foreign_key = HyperlinkedRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>, view_name='foreignkeytarget-detail') +    one_to_one = HyperlinkedRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>, view_name='onetoonetarget-detail') +    many_to_many = HyperlinkedRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>, view_name='manytomanytarget-detail') +""".strip() + + +HYPERLINKED_NESTED_REPR = """ +TestSerializer(): +    url = HyperlinkedIdentityField(view_name='relationalmodel-detail') +    foreign_key = NestedModelSerializer(read_only=True): +        id = IntegerField(label='ID', read_only=True) +        name = CharField(label='name', max_length=100) +    one_to_one = NestedModelSerializer(read_only=True): +        id = IntegerField(label='ID', read_only=True) +        name = CharField(label='name', max_length=100) +    many_to_many = NestedModelSerializer(many=True, read_only=True): +        id = IntegerField(label='ID', read_only=True) +        name = CharField(label='name', max_length=100) +""".strip() + + +class TestSerializerMappings(TestCase): +    def test_regular_fields(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RegularFieldsModel +        self.assertEqual(repr(TestSerializer()), REGULAR_FIELDS_REPR) + +    def test_flat_relational_fields(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RelationalModel +        self.assertEqual(repr(TestSerializer()), RELATIONAL_FLAT_REPR) + +    def test_nested_relational_fields(self): +        class TestSerializer(serializers.ModelSerializer): +            class Meta: +                model = RelationalModel +                depth = 1 +        self.assertEqual(repr(TestSerializer()), RELATIONAL_NESTED_REPR) + +    def test_flat_hyperlinked_fields(self): +        class TestSerializer(serializers.HyperlinkedModelSerializer): +            class Meta: +                model = RelationalModel +        self.assertEqual(repr(TestSerializer()), HYPERLINKED_FLAT_REPR) + +    def test_nested_hyperlinked_fields(self): +        class TestSerializer(serializers.HyperlinkedModelSerializer): +            class Meta: +                model = RelationalModel +                depth = 1 +        self.assertEqual(repr(TestSerializer()), HYPERLINKED_NESTED_REPR) diff --git a/tests/test_serializers.py b/tests/test_modelinfo.py index 31c41730..254a33c9 100644 --- a/tests/test_serializers.py +++ b/tests/test_modelinfo.py @@ -1,6 +1,6 @@  from django.test import TestCase  from django.utils import six -from rest_framework.serializers import _resolve_model +from rest_framework.utils.modelinfo import _resolve_model  from tests.models import BasicModel diff --git a/tests/test_relations.py b/tests/test_relations.py index a30b12e6..b1bc66b6 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -22,18 +22,18 @@  #         https://github.com/tomchristie/django-rest-framework/issues/446  #         """  #         field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) -#         self.assertRaises(serializers.ValidationError, field.from_native, '') -#         self.assertRaises(serializers.ValidationError, field.from_native, []) +#         self.assertRaises(serializers.ValidationError, field.to_primative, '') +#         self.assertRaises(serializers.ValidationError, field.to_primative, [])  #     def test_hyperlinked_related_field_with_empty_string(self):  #         field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') -#         self.assertRaises(serializers.ValidationError, field.from_native, '') -#         self.assertRaises(serializers.ValidationError, field.from_native, []) +#         self.assertRaises(serializers.ValidationError, field.to_primative, '') +#         self.assertRaises(serializers.ValidationError, field.to_primative, [])  #     def test_slug_related_field_with_empty_string(self):  #         field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') -#         self.assertRaises(serializers.ValidationError, field.from_native, '') -#         self.assertRaises(serializers.ValidationError, field.from_native, []) +#         self.assertRaises(serializers.ValidationError, field.to_primative, '') +#         self.assertRaises(serializers.ValidationError, field.to_primative, [])  # class TestManyRelatedMixin(TestCase): diff --git a/tests/test_serializer_empty.py b/tests/test_serializer_empty.py index d0006ad3..4e4a7b42 100644 --- a/tests/test_serializer_empty.py +++ b/tests/test_serializer_empty.py @@ -6,7 +6,7 @@  #     def test_empty_serializer(self):  #         class FooBarSerializer(serializers.Serializer):  #             foo = serializers.IntegerField() -#             bar = serializers.SerializerMethodField('get_bar') +#             bar = serializers.MethodField()  #             def get_bar(self, obj):  #                 return 'bar' | 
