aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2014-09-09 17:46:28 +0100
committerTom Christie2014-09-09 17:46:28 +0100
commitb1c07670ca65084c5fef2bbb63d1f4163763014b (patch)
tree4f08654d698990d97fe275d8dbbbcc1164524086
parent21980b800d04a1d82a6003823abfdf4ab80ae979 (diff)
downloaddjango-rest-framework-b1c07670ca65084c5fef2bbb63d1f4163763014b.tar.bz2
Fleshing out serializer fields
-rw-r--r--rest_framework/fields.py591
-rw-r--r--rest_framework/serializers.py380
-rw-r--r--rest_framework/utils/humanize_datetime.py47
-rw-r--r--rest_framework/utils/modelinfo.py97
-rw-r--r--rest_framework/utils/representation.py72
-rw-r--r--tests/test_model_field_mappings.py160
-rw-r--r--tests/test_modelinfo.py (renamed from tests/test_serializers.py)2
-rw-r--r--tests/test_relations.py12
-rw-r--r--tests/test_serializer_empty.py2
9 files changed, 1040 insertions, 323 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 250c0579..043a44ed 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,8 +1,18 @@
+from django.conf import settings
from django.core import validators
from django.core.exceptions import ValidationError
+from django.utils import timezone
+from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
-from rest_framework.utils import html
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import ISO_8601
+from rest_framework.compat import smart_text
+from rest_framework.settings import api_settings
+from rest_framework.utils import html, representation, humanize_datetime
+import datetime
+import decimal
import inspect
+import warnings
class empty:
@@ -71,22 +81,22 @@ class SkipField(Exception):
pass
+NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
+NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
+NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
+NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
+MISSING_ERROR_MESSAGE = (
+ 'ValidationError raised by `{class_name}`, but error key `{key}` does '
+ 'not exist in the `error_messages` dictionary.'
+)
+
+
class Field(object):
_creation_counter = 0
- MESSAGES = {
- 'required': 'This field is required.'
+ default_error_messages = {
+ 'required': _('This field is required.')
}
-
- _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
- _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
- _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
- _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
- _MISSING_ERROR_MESSAGE = (
- 'ValidationError raised by `{class_name}`, but error key `{key}` does '
- 'not exist in the `MESSAGES` dictionary.'
- )
-
default_validators = []
def __init__(self, read_only=False, write_only=False,
@@ -100,10 +110,10 @@ class Field(object):
required = default is empty and not read_only
# Some combinations of keyword arguments do not make sense.
- assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY
- assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED
- assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT
- assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT
+ assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
+ assert not (read_only and required), NOT_READ_ONLY_REQUIRED
+ assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
+ assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
self.read_only = read_only
self.write_only = write_only
@@ -113,7 +123,14 @@ class Field(object):
self.initial = initial
self.label = label
self.style = {} if style is None else style
- self.validators = self.default_validators + validators
+ self.validators = validators or self.default_validators[:]
+
+ # Collect default error message from self and parent classes
+ messages = {}
+ for cls in reversed(self.__class__.__mro__):
+ messages.update(getattr(cls, 'default_error_messages', {}))
+ messages.update(error_messages or {})
+ self.error_messages = messages
def bind(self, field_name, parent, root):
"""
@@ -186,12 +203,14 @@ class Field(object):
self.fail('required')
return self.get_default()
- self.run_validators(data)
- return self.to_native(data)
+ value = self.to_native(data)
+ self.run_validators(value)
+ return value
def run_validators(self, value):
if value in validators.EMPTY_VALUES:
return
+
errors = []
for validator in self.validators:
try:
@@ -218,33 +237,32 @@ class Field(object):
A helper method that simply raises a validation error.
"""
try:
- raise ValidationError(self.MESSAGES[key].format(**kwargs))
+ msg = self.error_messages[key]
except KeyError:
class_name = self.__class__.__name__
- msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
+ msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
+ raise ValidationError(msg.format(**kwargs))
def __new__(cls, *args, **kwargs):
+ """
+ When a field is instantiated, we store the arguments that were used,
+ so that we can present a helpful representation of the object.
+ """
instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def __repr__(self):
- arg_string = ', '.join([repr(val) for val in self._args])
- kwarg_string = ', '.join([
- '%s=%s' % (key, repr(val)) for key, val in self._kwargs.items()
- ])
- if arg_string and kwarg_string:
- arg_string += ', '
- class_name = self.__class__.__name__
- return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
+ return representation.field_repr(self)
+# Boolean types...
+
class BooleanField(Field):
- MESSAGES = {
- 'required': 'This field is required.',
- 'invalid_value': '`{input}` is not a valid boolean.'
+ default_error_messages = {
+ 'invalid': _('`{input}` is not a valid boolean.')
}
TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
@@ -261,13 +279,23 @@ class BooleanField(Field):
return True
elif data in self.FALSE_VALUES:
return False
- self.fail('invalid_value', input=data)
+ self.fail('invalid', input=data)
+
+ def to_primative(self, value):
+ if value is None:
+ return None
+ if value in self.TRUE_VALUES:
+ return True
+ elif value in self.FALSE_VALUES:
+ return False
+ return bool(value)
+# String types...
+
class CharField(Field):
- MESSAGES = {
- 'required': 'This field is required.',
- 'blank': 'This field may not be blank.'
+ default_error_messages = {
+ 'blank': _('This field may not be blank.')
}
def __init__(self, **kwargs):
@@ -281,19 +309,364 @@ class CharField(Field):
self.fail('blank')
return str(data)
+ def to_primative(self, value):
+ if value is None:
+ return None
+ return str(value)
-class ChoiceField(Field):
- MESSAGES = {
- 'required': 'This field is required.',
- 'invalid_choice': '`{input}` is not a valid choice.'
+
+class EmailField(CharField):
+ default_error_messages = {
+ 'invalid': _('Enter a valid email address.')
+ }
+ default_validators = [validators.validate_email]
+
+ def to_native(self, data):
+ ret = super(EmailField, self).to_native(data)
+ if ret is None:
+ return None
+ return ret.strip()
+
+ def to_primative(self, value):
+ ret = super(EmailField, self).to_primative(value)
+ if ret is None:
+ return None
+ return ret.strip()
+
+
+class RegexField(CharField):
+ def __init__(self, regex, **kwargs):
+ kwargs['validators'] = (
+ [validators.RegexValidator(regex)] +
+ kwargs.get('validators', [])
+ )
+ super(RegexField, self).__init__(**kwargs)
+
+
+class SlugField(CharField):
+ default_error_messages = {
+ 'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
+ }
+ default_validators = [validators.validate_slug]
+
+
+class URLField(CharField):
+ default_error_messages = {
+ 'invalid': _("Enter a valid URL.")
+ }
+ default_validators = [validators.URLValidator()]
+
+
+# Number types...
+
+class IntegerField(Field):
+ default_error_messages = {
+ 'invalid': _('A valid integer is required.')
+ }
+
+ def __init__(self, **kwargs):
+ max_value = kwargs.pop('max_value', None)
+ min_value = kwargs.pop('min_value', None)
+ super(IntegerField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+ print self.__class__.__name__, self.validators
+
+ def to_native(self, data):
+ try:
+ data = int(str(data))
+ except (ValueError, TypeError):
+ self.fail('invalid')
+ return data
+
+ def to_primative(self, value):
+ if value is None:
+ return None
+ return int(value)
+
+
+class FloatField(Field):
+ default_error_messages = {
+ 'invalid': _("'%s' value must be a float."),
}
- coerce_to_type = str
def __init__(self, **kwargs):
- choices = kwargs.pop('choices')
+ max_value = kwargs.pop('max_value', None)
+ min_value = kwargs.pop('min_value', None)
+ super(FloatField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def to_primative(self, value):
+ if value is None:
+ return None
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ self.fail('invalid', value=value)
+
+ def to_native(self, value):
+ if value is None:
+ return None
+ return float(value)
+
+
+class DecimalField(Field):
+ default_error_messages = {
+ 'invalid': _('Enter a number.'),
+ 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
+ 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
+ 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
+ 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
+ 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
+ }
+
+ def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs):
+ self.max_value, self.min_value = max_value, min_value
+ self.max_digits, self.max_decimal_places = max_digits, decimal_places
+ super(DecimalField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def from_native(self, value):
+ """
+ Validates that the input is a decimal number. Returns a Decimal
+ instance. Returns None for empty values. Ensures that there are no more
+ than max_digits in the number, and no more than decimal_places digits
+ after the decimal point.
+ """
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ value = smart_text(value).strip()
+ try:
+ value = decimal.Decimal(value)
+ except decimal.DecimalException:
+ self.fail('invalid')
+
+ # Check for NaN. It is the only value that isn't equal to itself,
+ # so we can use this to identify NaN values.
+ if value != value:
+ self.fail('invalid')
+
+ # Check for infinity and negative infinity.
+ if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')):
+ self.fail('invalid')
+
+ sign, digittuple, exponent = value.as_tuple()
+ decimals = abs(exponent)
+ # digittuple doesn't include any leading zeros.
+ digits = len(digittuple)
+ if decimals > digits:
+ # We have leading zeros up to or past the decimal point. Count
+ # everything past the decimal point as a digit. We do not count
+ # 0 before the decimal point as a digit since that would mean
+ # we would not allow max_digits = decimal_places.
+ digits = decimals
+ whole_digits = digits - decimals
+
+ if self.max_digits is not None and digits > self.max_digits:
+ self.fail('max_digits', max_digits=self.max_digits)
+ if self.decimal_places is not None and decimals > self.decimal_places:
+ self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places)
+ if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
+ self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
+
+ return value
+
+
+# Date & time fields...
+
+class DateField(Field):
+ default_error_messages = {
+ 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
+ }
+ input_formats = api_settings.DATE_INPUT_FORMATS
+ format = api_settings.DATE_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateField, self).__init__(*args, **kwargs)
+
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ if isinstance(value, datetime.datetime):
+ if timezone and settings.USE_TZ and timezone.is_aware(value):
+ # Convert aware datetimes to the default time zone
+ # before casting them to dates (#17742).
+ default_timezone = timezone.get_default_timezone()
+ value = timezone.make_naive(value, default_timezone)
+ return value.date()
+ if isinstance(value, datetime.date):
+ return value
+
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_date(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.date()
+
+ humanized_format = humanize_datetime.date_formats(self.input_formats)
+ msg = self.error_messages['invalid'] % humanized_format
+ raise ValidationError(msg)
+
+ def to_primative(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if isinstance(value, datetime.datetime):
+ value = value.date()
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
+
+class DateTimeField(Field):
+ default_error_messages = {
+ 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
+ }
+ input_formats = api_settings.DATETIME_INPUT_FORMATS
+ format = api_settings.DATETIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(DateTimeField, self).__init__(*args, **kwargs)
+
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ if isinstance(value, datetime.datetime):
+ return value
+ if isinstance(value, datetime.date):
+ value = datetime.datetime(value.year, value.month, value.day)
+ if settings.USE_TZ:
+ # For backwards compatibility, interpret naive datetimes in
+ # local time. This won't work during DST change, but we can't
+ # do much about it, so we let the exceptions percolate up the
+ # call stack.
+ warnings.warn("DateTimeField received a naive datetime (%s)"
+ " while time zone support is active." % value,
+ RuntimeWarning)
+ default_timezone = timezone.get_default_timezone()
+ value = timezone.make_aware(value, default_timezone)
+ return value
+
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_datetime(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed
+
+ humanized_format = humanize_datetime.datetime_formats(self.input_formats)
+ msg = self.error_messages['invalid'] % humanized_format
+ raise ValidationError(msg)
+
+ def to_primative(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if self.format.lower() == ISO_8601:
+ ret = value.isoformat()
+ if ret.endswith('+00:00'):
+ ret = ret[:-6] + 'Z'
+ return ret
+ return value.strftime(self.format)
- assert choices, '`choices` argument is required and may not be empty'
+class TimeField(Field):
+ default_error_messages = {
+ 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
+ }
+ input_formats = api_settings.TIME_INPUT_FORMATS
+ format = api_settings.TIME_FORMAT
+
+ def __init__(self, input_formats=None, format=None, *args, **kwargs):
+ self.input_formats = input_formats if input_formats is not None else self.input_formats
+ self.format = format if format is not None else self.format
+ super(TimeField, self).__init__(*args, **kwargs)
+
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+
+ if isinstance(value, datetime.time):
+ return value
+
+ for format in self.input_formats:
+ if format.lower() == ISO_8601:
+ try:
+ parsed = parse_time(value)
+ except (ValueError, TypeError):
+ pass
+ else:
+ if parsed is not None:
+ return parsed
+ else:
+ try:
+ parsed = datetime.datetime.strptime(value, format)
+ except (ValueError, TypeError):
+ pass
+ else:
+ return parsed.time()
+
+ humanized_format = humanize_datetime.time_formats(self.input_formats)
+ msg = self.error_messages['invalid'] % humanized_format
+ raise ValidationError(msg)
+
+ def to_primative(self, value):
+ if value is None or self.format is None:
+ return value
+
+ if isinstance(value, datetime.datetime):
+ value = value.time()
+
+ if self.format.lower() == ISO_8601:
+ return value.isoformat()
+ return value.strftime(self.format)
+
+
+# Choice types...
+
+class ChoiceField(Field):
+ default_error_messages = {
+ 'invalid_choice': _('`{input}` is not a valid choice.')
+ }
+
+ def __init__(self, choices, **kwargs):
# Allow either single or paired choices style:
# choices = [1, 2, 3]
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
@@ -321,12 +694,14 @@ class ChoiceField(Field):
except KeyError:
self.fail('invalid_choice', input=data)
+ def to_primative(self, value):
+ return value
+
class MultipleChoiceField(ChoiceField):
- MESSAGES = {
- 'required': 'This field is required.',
- 'invalid_choice': '`{input}` is not a valid choice.',
- 'not_a_list': 'Expected a list of items but got type `{input_type}`'
+ default_error_messages = {
+ 'invalid_choice': _('`{input}` is not a valid choice.'),
+ 'not_a_list': _('Expected a list of items but got type `{input_type}`')
}
def to_native(self, data):
@@ -337,72 +712,42 @@ class MultipleChoiceField(ChoiceField):
for item in data
])
-
-class IntegerField(Field):
- MESSAGES = {
- 'required': 'This field is required.',
- 'invalid_integer': 'A valid integer is required.'
- }
-
- def __init__(self, **kwargs):
- max_value = kwargs.pop('max_value', None)
- min_value = kwargs.pop('min_value', None)
- super(IntegerField, self).__init__(**kwargs)
- if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
- if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
-
- def to_native(self, data):
- try:
- data = int(str(data))
- except (ValueError, TypeError):
- self.fail('invalid_integer')
- return data
-
def to_primative(self, value):
- if value is None:
- return None
- return int(value)
+ return value
-class EmailField(CharField):
+# File types...
+
+class FileField(Field):
pass # TODO
-class URLField(CharField):
+class ImageField(Field):
pass # TODO
-class RegexField(CharField):
- def __init__(self, **kwargs):
- self.regex = kwargs.pop('regex')
- super(CharField, self).__init__(**kwargs)
-
+# Advanced field types...
-class DateField(CharField):
- def __init__(self, **kwargs):
- self.input_formats = kwargs.pop('input_formats', None)
- super(DateField, self).__init__(**kwargs)
+class ReadOnlyField(Field):
+ """
+ A read-only field that simply returns the field value.
+ If the field is a method with no parameters, the method will be called
+ and it's return value used as the representation.
-class TimeField(CharField):
- def __init__(self, **kwargs):
- self.input_formats = kwargs.pop('input_formats', None)
- super(TimeField, self).__init__(**kwargs)
+ For example, the following would call `get_expiry_date()` on the object:
+ class ExampleSerializer(self):
+ expiry_date = ReadOnlyField(source='get_expiry_date')
+ """
-class DateTimeField(CharField):
def __init__(self, **kwargs):
- self.input_formats = kwargs.pop('input_formats', None)
- super(DateTimeField, self).__init__(**kwargs)
-
-
-class FileField(Field):
- pass # TODO
+ kwargs['read_only'] = True
+ super(ReadOnlyField, self).__init__(**kwargs)
+ def to_native(self, data):
+ raise NotImplemented('.to_native() not supported.')
-class ReadOnlyField(Field):
def to_primative(self, value):
if is_simple_callable(value):
return value()
@@ -410,11 +755,28 @@ class ReadOnlyField(Field):
class MethodField(Field):
+ """
+ A read-only field that get its representation from calling a method on the
+ parent serializer class. The method called will be of the form
+ "get_{field_name}", and should take a single argument, which is the
+ object being serialized.
+
+ For example:
+
+ class ExampleSerializer(self):
+ extra_info = MethodField()
+
+ def get_extra_info(self, obj):
+ return ... # Calculate some data to return.
+ """
def __init__(self, **kwargs):
kwargs['source'] = '*'
kwargs['read_only'] = True
super(MethodField, self).__init__(**kwargs)
+ def to_native(self, data):
+ raise NotImplemented('.to_native() not supported.')
+
def to_primative(self, value):
attr = 'get_{field_name}'.format(field_name=self.field_name)
method = getattr(self.parent, attr)
@@ -424,35 +786,14 @@ class MethodField(Field):
class ModelField(Field):
"""
A generic field that can be used against an arbitrary model field.
- """
- def __init__(self, *args, **kwargs):
- try:
- self.model_field = kwargs.pop('model_field')
- except KeyError:
- raise ValueError("ModelField requires 'model_field' kwarg")
-
- self.min_length = kwargs.pop('min_length',
- getattr(self.model_field, 'min_length', None))
- self.max_length = kwargs.pop('max_length',
- getattr(self.model_field, 'max_length', None))
- self.min_value = kwargs.pop('min_value',
- getattr(self.model_field, 'min_value', None))
- self.max_value = kwargs.pop('max_value',
- getattr(self.model_field, 'max_value', None))
-
- super(ModelField, self).__init__(*args, **kwargs)
-
- if self.min_length is not None:
- self.validators.append(validators.MinLengthValidator(self.min_length))
- if self.max_length is not None:
- self.validators.append(validators.MaxLengthValidator(self.max_length))
- if self.min_value is not None:
- self.validators.append(validators.MinValueValidator(self.min_value))
- if self.max_value is not None:
- self.validators.append(validators.MaxValueValidator(self.max_value))
- def get_attribute(self, instance):
- return get_attribute(instance, self.source_attrs[:-1])
+ This is used by `ModelSerializer` when dealing with custom model fields,
+ that do not have a serializer field to be mapped to.
+ """
+ def __init__(self, model_field, **kwargs):
+ self.model_field = model_field
+ kwargs['source'] = '*'
+ super(ModelField, self).__init__(**kwargs)
def to_native(self, data):
rel = getattr(self.model_field, 'rel', None)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 93226d32..8ca28387 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,15 +10,15 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
+from django.core import validators
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from collections import namedtuple, OrderedDict
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
-from rest_framework.utils import html
+from rest_framework.utils import html, modelinfo, representation
import copy
-import inspect
# Note: We do the following so that users of the framework can use this style:
#
@@ -146,12 +146,10 @@ class SerializerMetaclass(type):
class Serializer(BaseSerializer):
def __new__(cls, *args, **kwargs):
- many = kwargs.pop('many', False)
- if many:
- class DynamicListSerializer(ListSerializer):
- child = cls()
- return DynamicListSerializer(*args, **kwargs)
- return super(Serializer, cls).__new__(cls)
+ if kwargs.pop('many', False):
+ kwargs['child'] = cls()
+ return ListSerializer(*args, **kwargs)
+ return super(Serializer, cls).__new__(cls, *args, **kwargs)
def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {})
@@ -248,6 +246,9 @@ class Serializer(BaseSerializer):
error = errors.get(field.field_name)
yield FieldResult(field, value, error)
+ def __repr__(self):
+ return representation.serializer_repr(self, indent=1)
+
class ListSerializer(BaseSerializer):
child = None
@@ -299,26 +300,8 @@ class ListSerializer(BaseSerializer):
self.instance = self.create(self.validated_data)
return self.instance
-
-def _resolve_model(obj):
- """
- Resolve supplied `obj` to a Django model class.
-
- `obj` must be a Django model class itself, or a string
- representation of one. Useful in situtations like GH #1225 where
- Django may not have resolved a string-based reference to a model in
- another model's foreign key definition.
-
- String representations should have the format:
- 'appname.ModelName'
- """
- if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
- app_name, model_name = obj.split('.')
- return models.get_model(app_name, model_name)
- elif inspect.isclass(obj) and issubclass(obj, models.Model):
- return obj
- else:
- raise ValueError("{0} is not a Django model".format(obj))
+ def __repr__(self):
+ return representation.list_repr(self, indent=1)
class ModelSerializerOptions(object):
@@ -334,24 +317,25 @@ class ModelSerializerOptions(object):
class ModelSerializer(Serializer):
field_mapping = {
models.AutoField: IntegerField,
- # models.FloatField: FloatField,
+ models.BigIntegerField: IntegerField,
+ models.BooleanField: BooleanField,
+ models.CharField: CharField,
+ models.CommaSeparatedIntegerField: CharField,
+ models.DateField: DateField,
+ models.DateTimeField: DateTimeField,
+ models.DecimalField: DecimalField,
+ models.EmailField: EmailField,
+ models.FileField: FileField,
+ models.FloatField: FloatField,
models.IntegerField: IntegerField,
+ models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField,
- models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
- models.DateTimeField: DateTimeField,
- models.DateField: DateField,
+ models.SlugField: SlugField,
+ models.SmallIntegerField: IntegerField,
+ models.TextField: CharField,
models.TimeField: TimeField,
- # models.DecimalField: DecimalField,
- models.EmailField: EmailField,
- models.CharField: CharField,
models.URLField: URLField,
- # models.SlugField: SlugField,
- models.TextField: CharField,
- models.CommaSeparatedIntegerField: CharField,
- models.BooleanField: BooleanField,
- models.NullBooleanField: BooleanField,
- models.FileField: FileField,
# models.ImageField: ImageField,
}
@@ -392,85 +376,31 @@ class ModelSerializer(Serializer):
"""
Return all the fields that should be serialized for the model.
"""
- cls = self.opts.model
- opts = cls._meta.concrete_model._meta
+ info = modelinfo.get_field_info(self.opts.model)
ret = OrderedDict()
- nested = bool(self.opts.depth)
- # Deal with adding the primary key field
- pk_field = opts.pk
- while pk_field.rel and pk_field.rel.parent_link:
- # If model is a child via multitable inheritance, use parent's pk
- pk_field = pk_field.rel.to._meta.pk
-
- serializer_pk_field = self.get_pk_field(pk_field)
+ serializer_pk_field = self.get_pk_field(info.pk)
if serializer_pk_field:
- ret[pk_field.name] = serializer_pk_field
-
- # Deal with forward relationships
- forward_rels = [field for field in opts.fields if field.serialize]
- forward_rels += [field for field in opts.many_to_many if field.serialize]
-
- for model_field in forward_rels:
- has_through_model = False
-
- if model_field.rel:
- to_many = isinstance(model_field,
- models.fields.related.ManyToManyField)
- related_model = _resolve_model(model_field.rel.to)
-
- if to_many and not model_field.rel.through._meta.auto_created:
- has_through_model = True
+ ret[info.pk.name] = serializer_pk_field
- if model_field.rel and nested:
- field = self.get_nested_field(model_field, related_model, to_many)
- elif model_field.rel:
- field = self.get_related_field(model_field, related_model, to_many)
- else:
- field = self.get_field(model_field)
-
- if field:
- if has_through_model:
- field.read_only = True
+ # Regular fields
+ for field_name, field in info.fields.items():
+ ret[field_name] = self.get_field(field)
- ret[model_field.name] = field
-
- # Deal with reverse relationships
- if not self.opts.fields:
- reverse_rels = []
- else:
- # Reverse relationships are only included if they are explicitly
- # present in the `fields` option on the serializer
- reverse_rels = opts.get_all_related_objects()
- reverse_rels += opts.get_all_related_many_to_many_objects()
-
- for relation in reverse_rels:
- accessor_name = relation.get_accessor_name()
- if not self.opts.fields or accessor_name not in self.opts.fields:
- continue
- related_model = relation.model
- to_many = relation.field.rel.multiple
- has_through_model = False
- is_m2m = isinstance(relation.field,
- models.fields.related.ManyToManyField)
-
- if (
- is_m2m and
- hasattr(relation.field.rel, 'through') and
- not relation.field.rel.through._meta.auto_created
- ):
- has_through_model = True
-
- if nested:
- field = self.get_nested_field(None, related_model, to_many)
+ # Forward relations
+ for field_name, relation_info in info.forward_relations.items():
+ if self.opts.depth:
+ ret[field_name] = self.get_nested_field(*relation_info)
else:
- field = self.get_related_field(None, related_model, to_many)
+ ret[field_name] = self.get_related_field(*relation_info)
- if field:
- if has_through_model:
- field.read_only = True
-
- ret[accessor_name] = field
+ # Reverse relations
+ for accessor_name, relation_info in info.reverse_relations.items():
+ if accessor_name in self.opts.fields:
+ if self.opts.depth:
+ ret[field_name] = self.get_nested_field(*relation_info)
+ else:
+ ret[field_name] = self.get_related_field(*relation_info)
return ret
@@ -480,7 +410,7 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
- def get_nested_field(self, model_field, related_model, to_many):
+ def get_nested_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a nested relational field.
@@ -491,59 +421,148 @@ class ModelSerializer(Serializer):
model = related_model
depth = self.opts.depth - 1
- return NestedModelSerializer(many=to_many)
+ kwargs = {'read_only': True}
+ if to_many:
+ kwargs['many'] = True
+ return NestedModelSerializer(**kwargs)
- def get_related_field(self, model_field, related_model, to_many):
+ def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
Note that model_field will be `None` for reverse relationships.
"""
- # TODO: filter queryset using:
- # .using(db).complex_filter(self.rel.limit_choices_to)
+ kwargs = {
+ 'queryset': related_model._default_manager,
+ }
- kwargs = {}
- # 'queryset': related_model._default_manager,
- # 'many': to_many
- # }
+ if to_many:
+ kwargs['many'] = True
+
+ if has_through_model:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
if model_field:
- kwargs['required'] = not(model_field.null or model_field.blank)
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
+ kwargs.pop('queryset', None)
- return IntegerField(**kwargs)
- # TODO: return PrimaryKeyRelatedField(**kwargs)
+ return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
+ validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
+ if model_field.verbose_name is not None:
+ kwargs['label'] = model_field.verbose_name
+
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
+ # Read only implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
-
- if issubclass(model_field.__class__, models.TextField):
- kwargs['widget'] = widgets.Textarea
-
- if model_field.verbose_name is not None:
- kwargs['label'] = model_field.verbose_name
-
- if model_field.validators is not None:
- kwargs['validators'] = model_field.validators
+ # Having a default implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
+
+ # Ensure that max_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_length = getattr(model_field, 'max_length', None)
+ if max_length is not None:
+ kwargs['max_length'] = max_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxLengthValidator)
+ ]
+
+ # Ensure that min_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_length = getattr(model_field, 'min_length', None)
+ if min_length is not None:
+ kwargs['min_length'] = min_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinLengthValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MaxValueValidator)
+ ), None)
+ if max_value is not None:
+ kwargs['max_value'] = max_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxValueValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinValueValidator)
+ ), None)
+ if min_value is not None:
+ kwargs['min_value'] = min_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinValueValidator)
+ ]
+
+ # URLField does not need to include the URLValidator argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.URLField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.URLValidator)
+ ]
+
+ # EmailField does not need to include the validate_email argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.EmailField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_email
+ ]
+
+ # SlugField do not need to include the 'validate_slug' argument,
+ if isinstance(model_field, models.SlugField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_slug
+ ]
+
+ max_digits = getattr(model_field, 'max_digits', None)
+ if max_digits is not None:
+ kwargs['max_digits'] = max_digits
+
+ decimal_places = getattr(model_field, 'decimal_places', None)
+ if decimal_places is not None:
+ kwargs['decimal_places'] = decimal_places
+
+ if validator_kwarg:
+ kwargs['validators'] = validator_kwarg
+
+ # if issubclass(model_field.__class__, models.TextField):
+ # kwargs['widget'] = widgets.Textarea
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
@@ -555,31 +574,10 @@ class ModelSerializer(Serializer):
kwargs['empty'] = None
return ChoiceField(**kwargs)
- # put this below the ChoiceField because min_value isn't a valid initializer
- if issubclass(model_field.__class__, models.PositiveIntegerField) or \
- issubclass(model_field.__class__, models.PositiveSmallIntegerField):
- kwargs['min_value'] = 0
-
if model_field.null and \
issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True
- # attribute_dict = {
- # models.CharField: ['max_length'],
- # models.CommaSeparatedIntegerField: ['max_length'],
- # models.DecimalField: ['max_digits', 'decimal_places'],
- # models.EmailField: ['max_length'],
- # models.FileField: ['max_length'],
- # models.ImageField: ['max_length'],
- # models.SlugField: ['max_length'],
- # models.URLField: ['max_length'],
- # }
-
- # if model_field.__class__ in attribute_dict:
- # attributes = attribute_dict[model_field.__class__]
- # for attribute in attributes:
- # kwargs.update({attribute: getattr(model_field, attribute)})
-
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
@@ -594,28 +592,21 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
self.lookup_field = getattr(meta, 'lookup_field', None)
- self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)
class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
- _default_view_name = '%(model_name)s-detail'
- _hyperlink_field_class = HyperlinkedRelatedField
- _hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
if self.opts.view_name is None:
- self.opts.view_name = self._get_default_view_name(self.opts.model)
+ self.opts.view_name = self.get_default_view_name(self.opts.model)
- if self.opts.url_field_name not in fields:
- url_field = self._hyperlink_identify_field_class(
- view_name=self.opts.view_name,
- lookup_field=self.opts.lookup_field
- )
+ url_field_name = api_settings.URL_FIELD_NAME
+ if url_field_name not in fields:
ret = fields.__class__()
- ret[self.opts.url_field_name] = url_field
+ ret[url_field_name] = self.get_url_field()
ret.update(fields)
fields = ret
@@ -625,39 +616,48 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field)
- def get_related_field(self, model_field, related_model, to_many):
+ def get_url_field(self):
+ kwargs = {
+ 'view_name': self.get_default_view_name(self.opts.model)
+ }
+ if self.opts.lookup_field:
+ kwargs['lookup_field'] = self.opts.lookup_field
+ return HyperlinkedIdentityField(**kwargs)
+
+ def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
"""
- # TODO: filter queryset using:
- # .using(db).complex_filter(self.rel.limit_choices_to)
- # kwargs = {
- # 'queryset': related_model._default_manager,
- # 'view_name': self._get_default_view_name(related_model),
- # 'many': to_many
- # }
- kwargs = {}
+ kwargs = {
+ 'queryset': related_model._default_manager,
+ 'view_name': self.get_default_view_name(related_model),
+ }
+
+ if to_many:
+ kwargs['many'] = True
+
+ if has_through_model:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
if model_field:
- kwargs['required'] = not(model_field.null or model_field.blank)
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
+ if not model_field.editable:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
- return IntegerField(**kwargs)
- # if self.opts.lookup_field:
- # kwargs['lookup_field'] = self.opts.lookup_field
-
- # return self._hyperlink_field_class(**kwargs)
+ return HyperlinkedRelatedField(**kwargs)
- def _get_default_view_name(self, model):
+ def get_default_view_name(self, model):
"""
- Return the view name to use if 'view_name' is not specified in 'Meta'
+ Return the view name to use for related models.
"""
- model_meta = model._meta
- format_kwargs = {
- 'app_label': model_meta.app_label,
- 'model_name': model_meta.object_name.lower()
+ return '%(model_name)s-detail' % {
+ 'app_label': model._meta.app_label,
+ 'model_name': model._meta.object_name.lower()
}
- return self._default_view_name % format_kwargs
diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py
new file mode 100644
index 00000000..649f2abc
--- /dev/null
+++ b/rest_framework/utils/humanize_datetime.py
@@ -0,0 +1,47 @@
+"""
+Helper functions that convert strftime formats into more readable representations.
+"""
+from rest_framework import ISO_8601
+
+
+def datetime_formats(formats):
+ format = ', '.join(formats).replace(
+ ISO_8601,
+ 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
+ )
+ return humanize_strptime(format)
+
+
+def date_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
+ return humanize_strptime(format)
+
+
+def time_formats(formats):
+ format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
+ return humanize_strptime(format)
+
+
+def humanize_strptime(format_string):
+ # Note that we're missing some of the locale specific mappings that
+ # don't really make sense.
+ mapping = {
+ "%Y": "YYYY",
+ "%y": "YY",
+ "%m": "MM",
+ "%b": "[Jan-Dec]",
+ "%B": "[January-December]",
+ "%d": "DD",
+ "%H": "hh",
+ "%I": "hh", # Requires '%p' to differentiate from '%H'.
+ "%M": "mm",
+ "%S": "ss",
+ "%f": "uuuuuu",
+ "%a": "[Mon-Sun]",
+ "%A": "[Monday-Sunday]",
+ "%p": "[AM|PM]",
+ "%z": "[+HHMM|-HHMM]"
+ }
+ for key, val in mapping.items():
+ format_string = format_string.replace(key, val)
+ return format_string
diff --git a/rest_framework/utils/modelinfo.py b/rest_framework/utils/modelinfo.py
new file mode 100644
index 00000000..c0513886
--- /dev/null
+++ b/rest_framework/utils/modelinfo.py
@@ -0,0 +1,97 @@
+"""
+Helper functions for returning the field information that is associated
+with a model class.
+"""
+from collections import namedtuple, OrderedDict
+from django.db import models
+from django.utils import six
+import inspect
+
+FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
+RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
+
+
+def _resolve_model(obj):
+ """
+ Resolve supplied `obj` to a Django model class.
+
+ `obj` must be a Django model class itself, or a string
+ representation of one. Useful in situtations like GH #1225 where
+ Django may not have resolved a string-based reference to a model in
+ another model's foreign key definition.
+
+ String representations should have the format:
+ 'appname.ModelName'
+ """
+ if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
+ app_name, model_name = obj.split('.')
+ return models.get_model(app_name, model_name)
+ elif inspect.isclass(obj) and issubclass(obj, models.Model):
+ return obj
+ raise ValueError("{0} is not a Django model".format(obj))
+
+
+def get_field_info(model):
+ """
+ Given a model class, returns a `FieldInfo` instance containing metadata
+ about the various field types on the model.
+ """
+ opts = model._meta.concrete_model._meta
+
+ # Deal with the primary key.
+ pk = opts.pk
+ while pk.rel and pk.rel.parent_link:
+ # If model is a child via multitable inheritance, use parent's pk.
+ pk = pk.rel.to._meta.pk
+
+ # Deal with regular fields.
+ fields = OrderedDict()
+ for field in [field for field in opts.fields if field.serialize and not field.rel]:
+ fields[field.name] = field
+
+ # Deal with forward relationships.
+ forward_relations = OrderedDict()
+ for field in [field for field in opts.fields if field.serialize and field.rel]:
+ forward_relations[field.name] = RelationInfo(
+ field=field,
+ related=_resolve_model(field.rel.to),
+ to_many=False,
+ has_through_model=False
+ )
+
+ # Deal with forward many-to-many relationships.
+ for field in [field for field in opts.many_to_many if field.serialize]:
+ forward_relations[field.name] = RelationInfo(
+ field=field,
+ related=_resolve_model(field.rel.to),
+ to_many=True,
+ has_through_model=(
+ not field.rel.through._meta.auto_created
+ )
+ )
+
+ # Deal with reverse relationships.
+ reverse_relations = OrderedDict()
+ for relation in opts.get_all_related_objects():
+ accessor_name = relation.get_accessor_name()
+ reverse_relations[accessor_name] = RelationInfo(
+ field=None,
+ related=relation.model,
+ to_many=relation.field.rel.multiple,
+ has_through_model=False
+ )
+
+ # Deal with reverse many-to-many relationships.
+ for relation in opts.get_all_related_many_to_many_objects():
+ accessor_name = relation.get_accessor_name()
+ reverse_relations[accessor_name] = RelationInfo(
+ field=None,
+ related=relation.model,
+ to_many=True,
+ has_through_model=(
+ hasattr(relation.field.rel, 'through') and
+ not relation.field.rel.through._meta.auto_created
+ )
+ )
+
+ return FieldInfo(pk, fields, forward_relations, reverse_relations)
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
new file mode 100644
index 00000000..1de21597
--- /dev/null
+++ b/rest_framework/utils/representation.py
@@ -0,0 +1,72 @@
+"""
+Helper functions for creating user-friendly representations
+of serializer classes and serializer fields.
+"""
+import re
+
+
+def smart_repr(value):
+ value = repr(value)
+
+ # Representations like u'help text'
+ # should simply be presented as 'help text'
+ if value.startswith("u'") and value.endswith("'"):
+ return value[1:]
+
+ # Representations like
+ # <django.core.validators.RegexValidator object at 0x1047af050>
+ # Should be presented as
+ # <django.core.validators.RegexValidator object>
+ value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value)
+
+ return value
+
+
+def field_repr(field, force_many=False):
+ kwargs = field._kwargs
+ if force_many:
+ kwargs = kwargs.copy()
+ kwargs['many'] = True
+ kwargs.pop('child', None)
+
+ arg_string = ', '.join([smart_repr(val) for val in field._args])
+ kwarg_string = ', '.join([
+ '%s=%s' % (key, smart_repr(val))
+ for key, val in sorted(kwargs.items())
+ ])
+ if arg_string and kwarg_string:
+ arg_string += ', '
+
+ if force_many:
+ class_name = force_many.__class__.__name__
+ else:
+ class_name = field.__class__.__name__
+
+ return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
+
+
+def serializer_repr(serializer, indent, force_many=None):
+ ret = field_repr(serializer, force_many) + ':'
+ indent_str = ' ' * indent
+
+ if force_many:
+ fields = force_many.fields
+ else:
+ fields = serializer.fields
+
+ for field_name, field in fields.items():
+ ret += '\n' + indent_str + field_name + ' = '
+ if hasattr(field, 'fields'):
+ ret += serializer_repr(field, indent + 1)
+ elif hasattr(field, 'child'):
+ ret += list_repr(field, indent + 1)
+ else:
+ ret += field_repr(field)
+ return ret
+
+
+def list_repr(serializer, indent):
+ child = serializer.child
+ if hasattr(child, 'fields'):
+ return serializer_repr(serializer, indent, force_many=child)
+ return field_repr(serializer)
diff --git a/tests/test_model_field_mappings.py b/tests/test_model_field_mappings.py
new file mode 100644
index 00000000..dc254da4
--- /dev/null
+++ b/tests/test_model_field_mappings.py
@@ -0,0 +1,160 @@
+"""
+The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially
+shortcuts for automatically creating serializers based on a given model class.
+
+These tests deal with ensuring that we correctly map the model fields onto
+an appropriate set of serializer fields for each case.
+"""
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+# Models for testing regular field mapping
+
+class RegularFieldsModel(models.Model):
+ auto_field = models.AutoField(primary_key=True)
+ big_integer_field = models.BigIntegerField()
+ boolean_field = models.BooleanField()
+ char_field = models.CharField(max_length=100)
+ comma_seperated_integer_field = models.CommaSeparatedIntegerField(max_length=100)
+ date_field = models.DateField()
+ datetime_field = models.DateTimeField()
+ decimal_field = models.DecimalField(max_digits=3, decimal_places=1)
+ email_field = models.EmailField(max_length=100)
+ float_field = models.FloatField()
+ integer_field = models.IntegerField()
+ null_boolean_field = models.NullBooleanField()
+ positive_integer_field = models.PositiveIntegerField()
+ positive_small_integer_field = models.PositiveSmallIntegerField()
+ slug_field = models.SlugField(max_length=100)
+ small_integer_field = models.SmallIntegerField()
+ text_field = models.TextField()
+ time_field = models.TimeField()
+ url_field = models.URLField(max_length=100)
+
+
+REGULAR_FIELDS_REPR = """
+TestSerializer():
+ auto_field = IntegerField(label='auto field', read_only=True)
+ big_integer_field = IntegerField(label='big integer field')
+ boolean_field = BooleanField(default=False, label='boolean field')
+ char_field = CharField(label='char field', max_length=100)
+ comma_seperated_integer_field = CharField(label='comma seperated integer field', max_length=100, validators=[<django.core.validators.RegexValidator object>])
+ date_field = DateField(label='date field')
+ datetime_field = DateTimeField(label='datetime field')
+ decimal_field = DecimalField(decimal_places=1, label='decimal field', max_digits=3)
+ email_field = EmailField(label='email field', max_length=100)
+ float_field = FloatField(label='float field')
+ integer_field = IntegerField(label='integer field')
+ null_boolean_field = BooleanField(label='null boolean field', required=False)
+ positive_integer_field = IntegerField(label='positive integer field')
+ positive_small_integer_field = IntegerField(label='positive small integer field')
+ slug_field = SlugField(label='slug field', max_length=100)
+ small_integer_field = IntegerField(label='small integer field')
+ text_field = CharField(label='text field')
+ time_field = TimeField(label='time field')
+ url_field = URLField(label='url field', max_length=100)
+""".strip()
+
+
+# Model for testing relational field mapping
+
+class ForeignKeyTarget(models.Model):
+ char_field = models.CharField(max_length=100)
+
+
+class ManyToManyTarget(models.Model):
+ char_field = models.CharField(max_length=100)
+
+
+class OneToOneTarget(models.Model):
+ char_field = models.CharField(max_length=100)
+
+
+class RelationalModel(models.Model):
+ foreign_key = models.ForeignKey(ForeignKeyTarget)
+ many_to_many = models.ManyToManyField(ManyToManyTarget)
+ one_to_one = models.OneToOneField(OneToOneTarget)
+
+
+RELATIONAL_FLAT_REPR = """
+TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ foreign_key = PrimaryKeyRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>)
+ one_to_one = PrimaryKeyRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>)
+ many_to_many = PrimaryKeyRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>)
+""".strip()
+
+
+RELATIONAL_NESTED_REPR = """
+TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ foreign_key = NestedModelSerializer(read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(label='name', max_length=100)
+ one_to_one = NestedModelSerializer(read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(label='name', max_length=100)
+ many_to_many = NestedModelSerializer(many=True, read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(label='name', max_length=100)
+""".strip()
+
+
+HYPERLINKED_FLAT_REPR = """
+TestSerializer():
+ url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
+ foreign_key = HyperlinkedRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>, view_name='foreignkeytarget-detail')
+ one_to_one = HyperlinkedRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>, view_name='onetoonetarget-detail')
+ many_to_many = HyperlinkedRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>, view_name='manytomanytarget-detail')
+""".strip()
+
+
+HYPERLINKED_NESTED_REPR = """
+TestSerializer():
+ url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
+ foreign_key = NestedModelSerializer(read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(label='name', max_length=100)
+ one_to_one = NestedModelSerializer(read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(label='name', max_length=100)
+ many_to_many = NestedModelSerializer(many=True, read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(label='name', max_length=100)
+""".strip()
+
+
+class TestSerializerMappings(TestCase):
+ def test_regular_fields(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ self.assertEqual(repr(TestSerializer()), REGULAR_FIELDS_REPR)
+
+ def test_flat_relational_fields(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+ self.assertEqual(repr(TestSerializer()), RELATIONAL_FLAT_REPR)
+
+ def test_nested_relational_fields(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+ depth = 1
+ self.assertEqual(repr(TestSerializer()), RELATIONAL_NESTED_REPR)
+
+ def test_flat_hyperlinked_fields(self):
+ class TestSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RelationalModel
+ self.assertEqual(repr(TestSerializer()), HYPERLINKED_FLAT_REPR)
+
+ def test_nested_hyperlinked_fields(self):
+ class TestSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RelationalModel
+ depth = 1
+ self.assertEqual(repr(TestSerializer()), HYPERLINKED_NESTED_REPR)
diff --git a/tests/test_serializers.py b/tests/test_modelinfo.py
index 31c41730..254a33c9 100644
--- a/tests/test_serializers.py
+++ b/tests/test_modelinfo.py
@@ -1,6 +1,6 @@
from django.test import TestCase
from django.utils import six
-from rest_framework.serializers import _resolve_model
+from rest_framework.utils.modelinfo import _resolve_model
from tests.models import BasicModel
diff --git a/tests/test_relations.py b/tests/test_relations.py
index a30b12e6..b1bc66b6 100644
--- a/tests/test_relations.py
+++ b/tests/test_relations.py
@@ -22,18 +22,18 @@
# https://github.com/tomchristie/django-rest-framework/issues/446
# """
# field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
-# self.assertRaises(serializers.ValidationError, field.from_native, '')
-# self.assertRaises(serializers.ValidationError, field.from_native, [])
+# self.assertRaises(serializers.ValidationError, field.to_primative, '')
+# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# def test_hyperlinked_related_field_with_empty_string(self):
# field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
-# self.assertRaises(serializers.ValidationError, field.from_native, '')
-# self.assertRaises(serializers.ValidationError, field.from_native, [])
+# self.assertRaises(serializers.ValidationError, field.to_primative, '')
+# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# def test_slug_related_field_with_empty_string(self):
# field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
-# self.assertRaises(serializers.ValidationError, field.from_native, '')
-# self.assertRaises(serializers.ValidationError, field.from_native, [])
+# self.assertRaises(serializers.ValidationError, field.to_primative, '')
+# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# class TestManyRelatedMixin(TestCase):
diff --git a/tests/test_serializer_empty.py b/tests/test_serializer_empty.py
index d0006ad3..4e4a7b42 100644
--- a/tests/test_serializer_empty.py
+++ b/tests/test_serializer_empty.py
@@ -6,7 +6,7 @@
# def test_empty_serializer(self):
# class FooBarSerializer(serializers.Serializer):
# foo = serializers.IntegerField()
-# bar = serializers.SerializerMethodField('get_bar')
+# bar = serializers.MethodField()
# def get_bar(self, obj):
# return 'bar'