aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rest_framework/fields.py6
-rw-r--r--rest_framework/relations.py9
-rw-r--r--rest_framework/serializers.py464
-rw-r--r--rest_framework/utils/field_mapping.py215
-rw-r--r--rest_framework/utils/model_meta.py46
-rw-r--r--tests/models.py7
-rw-r--r--tests/test_model_field_mappings.py16
-rw-r--r--tests/test_response.py9
-rw-r--r--tests/test_routers.py7
9 files changed, 384 insertions, 395 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 1818e705..0c78b3fb 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -80,10 +80,6 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value
-def field_name_to_label(field_name):
- return field_name.replace('_', ' ').capitalize()
-
-
class SkipField(Exception):
pass
@@ -162,7 +158,7 @@ class Field(object):
# `self.label` should deafult to being based on the field name.
if self.label is None:
- self.label = field_name_to_label(self.field_name)
+ self.label = field_name.replace('_', ' ').capitalize()
# self.source should default to being the same as the field name.
if self.source is None:
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 46fe55ef..9f44ab63 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField):
'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',
}
- def __init__(self, view_name, **kwargs):
+ def __init__(self, view_name=None, **kwargs):
+ assert view_name is not None, 'The `view_name` argument is required.'
self.view_name = view_name
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
@@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
URL of relationships to other objects.
"""
- def __init__(self, view_name, **kwargs):
+ def __init__(self, view_name=None, **kwargs):
+ assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True
kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
@@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField):
'invalid': _('Invalid value.'),
}
- def __init__(self, slug_field, **kwargs):
+ def __init__(self, slug_field=None, **kwargs):
+ assert slug_field is not None, 'The `slug_field` argument is required.'
self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 1fea1380..99dcc349 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,17 +10,19 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
-from django.core import validators
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from django.utils.datastructures import SortedDict
-from django.utils.text import capfirst
from collections import namedtuple
-from rest_framework.compat import clean_manytomany_helptext
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation
+from rest_framework.utils.field_mapping import (
+ get_url_kwargs, get_field_kwargs,
+ get_relation_kwargs, get_nested_relation_kwargs,
+ lookup_class
+)
import copy
# Note: We do the following so that users of the framework can use this style:
@@ -126,7 +128,7 @@ class SerializerMetaclass(type):
"""
@classmethod
- def _get_fields(cls, bases, attrs):
+ def _get_declared_fields(cls, bases, attrs):
fields = [(field_name, attrs.pop(field_name))
for field_name, obj in list(attrs.items())
if isinstance(obj, Field)]
@@ -136,25 +138,18 @@ class SerializerMetaclass(type):
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
for base in bases[::-1]:
- if hasattr(base, 'base_fields'):
- fields = list(base.base_fields.items()) + fields
+ if hasattr(base, '_declared_fields'):
+ fields = list(base._declared_fields.items()) + fields
return SortedDict(fields)
def __new__(cls, name, bases, attrs):
- attrs['base_fields'] = cls._get_fields(bases, attrs)
+ attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
-
- def __new__(cls, *args, **kwargs):
- if kwargs.pop('many', False):
- kwargs['child'] = cls()
- return ListSerializer(*args, **kwargs)
- return super(Serializer, cls).__new__(cls, *args, **kwargs)
-
def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
@@ -165,14 +160,22 @@ class Serializer(BaseSerializer):
# Every new serializer is created with a clone of the field instances.
# This allows users to dynamically modify the fields on a serializer
# instance without affecting every other serializer class.
- self.fields = self.get_fields()
+ self.fields = self._get_base_fields()
# Setup all the child fields, to provide them with the current context.
for field_name, field in self.fields.items():
field.bind(field_name, self, self)
- def get_fields(self):
- return copy.deepcopy(self.base_fields)
+ def __new__(cls, *args, **kwargs):
+ # We override this method in order to automagically create
+ # `ListSerializer` classes instead when `many=True` is set.
+ if kwargs.pop('many', False):
+ kwargs['child'] = cls()
+ return ListSerializer(*args, **kwargs)
+ return super(Serializer, cls).__new__(cls, *args, **kwargs)
+
+ def _get_base_fields(self):
+ return copy.deepcopy(self._declared_fields)
def bind(self, field_name, parent, root):
# If the serializer is used as a field then when it becomes bound
@@ -312,39 +315,8 @@ class ListSerializer(BaseSerializer):
return representation.list_repr(self, indent=1)
-class ModelSerializerOptions(object):
- """
- Meta class options for ModelSerializer
- """
- def __init__(self, meta):
- self.model = getattr(meta, 'model')
- self.fields = getattr(meta, 'fields', ())
- self.depth = getattr(meta, 'depth', 0)
-
-
-def lookup_class(mapping, instance):
- """
- Takes a dictionary with classes as keys, and an object.
- Traverses the object's inheritance hierarchy in method
- resolution order, and returns the first matching value
- from the dictionary or raises a KeyError if nothing matches.
- """
- for cls in inspect.getmro(instance.__class__):
- if cls in mapping:
- return mapping[cls]
- raise KeyError('Class %s not found in lookup.', cls.__name__)
-
-
-def needs_label(model_field, field_name):
- """
- Returns `True` if the label based on the model's verbose name
- is not equal to the default label it would have based on it's field name.
- """
- return capfirst(model_field.verbose_name) != field_name_to_label(field_name)
-
-
class ModelSerializer(Serializer):
- field_mapping = {
+ _field_mapping = {
models.AutoField: IntegerField,
models.BigIntegerField: IntegerField,
models.BooleanField: BooleanField,
@@ -368,16 +340,10 @@ class ModelSerializer(Serializer):
models.TimeField: TimeField,
models.URLField: URLField,
}
- nested_class = None # We fill this in at the end of this module.
-
- _options_class = ModelSerializerOptions
-
- def __init__(self, *args, **kwargs):
- self.opts = self._options_class(self.Meta)
- super(ModelSerializer, self).__init__(*args, **kwargs)
+ _related_class = PrimaryKeyRelatedField
def create(self, attrs):
- ModelClass = self.opts.model
+ ModelClass = self.Meta.model
return ModelClass.objects.create(**attrs)
def update(self, obj, attrs):
@@ -385,319 +351,97 @@ class ModelSerializer(Serializer):
setattr(obj, attr, value)
obj.save()
- def get_fields(self):
- # Get the explicitly declared fields.
- fields = copy.deepcopy(self.base_fields)
+ def _get_base_fields(self):
+ declared_fields = copy.deepcopy(self._declared_fields)
- # Add in the default fields.
- for key, val in self.get_default_fields().items():
- if key not in fields:
- fields[key] = val
-
- # If `fields` is set on the `Meta` class,
- # then use only those fields, and in that order.
- if self.opts.fields:
- fields = SortedDict([
- (key, fields[key]) for key in self.opts.fields
- ])
-
- return fields
-
- def get_default_fields(self):
- """
- Return all the fields that should be serialized for the model.
- """
- info = model_meta.get_field_info(self.opts.model)
ret = SortedDict()
+ model = getattr(self.Meta, 'model')
+ fields = getattr(self.Meta, 'fields', None)
+ depth = getattr(self.Meta, 'depth', 0)
+
+ # Retrieve metadata about fields & relationships on the model class.
+ info = model_meta.get_field_info(model)
+
+ # Use the default set of fields if none is supplied explicitly.
+ if fields is None:
+ fields = self._get_default_field_names(declared_fields, info)
+
+ for field_name in fields:
+ if field_name in declared_fields:
+ # Field is explicitly declared on the class, use that.
+ ret[field_name] = declared_fields[field_name]
+ continue
+
+ elif field_name == api_settings.URL_FIELD_NAME:
+ # Create the URL field.
+ field_cls = HyperlinkedIdentityField
+ kwargs = get_url_kwargs(model)
+
+ elif field_name in info.fields_and_pk:
+ # Create regular model fields.
+ model_field = info.fields_and_pk[field_name]
+ field_cls = lookup_class(self._field_mapping, model_field)
+ kwargs = get_field_kwargs(field_name, model_field)
+ if 'choices' in kwargs:
+ # Fields with choices get coerced into `ChoiceField`
+ # instead of using their regular typed field.
+ field_cls = ChoiceField
+ if not issubclass(field_cls, ModelField):
+ # `model_field` is only valid for the fallback case of
+ # `ModelField`, which is used when no other typed field
+ # matched to the model field.
+ kwargs.pop('model_field', None)
+
+ elif field_name in info.relations:
+ # Create forward and reverse relationships.
+ relation_info = info.relations[field_name]
+ if depth:
+ field_cls = self._get_nested_class(depth, relation_info)
+ kwargs = get_nested_relation_kwargs(relation_info)
+ else:
+ field_cls = self._related_class
+ kwargs = get_relation_kwargs(field_name, relation_info)
+ # `view_name` is only valid for hyperlinked relationships.
+ if not issubclass(field_cls, HyperlinkedRelatedField):
+ kwargs.pop('view_name', None)
- # URL field
- serializer_url_field = self.get_url_field()
- if serializer_url_field:
- ret[api_settings.URL_FIELD_NAME] = serializer_url_field
-
- # Primary key field
- field_name = info.pk.name
- serializer_pk_field = self.get_pk_field(field_name, info.pk)
- if serializer_pk_field:
- ret[field_name] = serializer_pk_field
-
- # Regular fields
- for field_name, field in info.fields.items():
- ret[field_name] = self.get_field(field_name, field)
-
- # Forward relations
- for field_name, relation_info in info.forward_relations.items():
- if self.opts.depth:
- ret[field_name] = self.get_nested_field(field_name, *relation_info)
else:
- ret[field_name] = self.get_related_field(field_name, *relation_info)
+ assert False, 'Field name `%s` is not valid.' % field_name
- # Reverse relations
- for accessor_name, relation_info in info.reverse_relations.items():
- if accessor_name in self.opts.fields:
- if self.opts.depth:
- ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info)
- else:
- ret[accessor_name] = self.get_related_field(accessor_name, *relation_info)
+ ret[field_name] = field_cls(**kwargs)
return ret
- def get_url_field(self):
- return None
-
- def get_pk_field(self, field_name, model_field):
- """
- Returns a default instance of the pk field.
- """
- return self.get_field(field_name, model_field)
-
- def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model):
- """
- Creates a default instance of a nested relational field.
+ def _get_default_field_names(self, declared_fields, model_info):
+ return (
+ [model_info.pk.name] +
+ list(declared_fields.keys()) +
+ list(model_info.fields.keys()) +
+ list(model_info.forward_relations.keys())
+ )
- Note that model_field will be `None` for reverse relationships.
- """
- class NestedModelSerializer(self.nested_class):
+ def _get_nested_class(self, nested_depth, relation_info):
+ class NestedSerializer(ModelSerializer):
class Meta:
- model = related_model
- depth = self.opts.depth - 1
-
- kwargs = {'read_only': True}
- if to_many:
- kwargs['many'] = True
- return NestedModelSerializer(**kwargs)
-
- def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
- """
- Creates a default instance of a flat relational field.
-
- Note that model_field will be `None` for reverse relationships.
- """
- kwargs = {
- 'queryset': related_model._default_manager,
- }
-
- if to_many:
- kwargs['many'] = True
-
- if has_through_model:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
-
- if model_field:
- if model_field.null or model_field.blank:
- kwargs['required'] = False
- if model_field.verbose_name and needs_label(model_field, field_name):
- kwargs['label'] = capfirst(model_field.verbose_name)
- if not model_field.editable:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
- help_text = clean_manytomany_helptext(model_field.help_text)
- if help_text:
- kwargs['help_text'] = help_text
-
- return PrimaryKeyRelatedField(**kwargs)
-
- def get_field(self, field_name, model_field):
- """
- Creates a default instance of a basic non-relational field.
- """
- serializer_cls = lookup_class(self.field_mapping, model_field)
- kwargs = {}
- validator_kwarg = model_field.validators
-
- if model_field.null or model_field.blank:
- kwargs['required'] = False
-
- if model_field.verbose_name and needs_label(model_field, field_name):
- kwargs['label'] = capfirst(model_field.verbose_name)
-
- if model_field.help_text:
- kwargs['help_text'] = model_field.help_text
-
- if isinstance(model_field, models.AutoField) or not model_field.editable:
- kwargs['read_only'] = True
- # Read only implies that the field is not required.
- # We have a cleaner repr on the instance if we don't set it.
- kwargs.pop('required', None)
-
- if model_field.has_default():
- kwargs['default'] = model_field.get_default()
- # Having a default implies that the field is not required.
- # We have a cleaner repr on the instance if we don't set it.
- kwargs.pop('required', None)
-
- if model_field.flatchoices:
- # If this model field contains choices, then use a ChoiceField,
- # rather than the standard serializer field for this type.
- # Note that we return this prior to setting any validation type
- # keyword arguments, as those are not valid initializers.
- kwargs['choices'] = model_field.flatchoices
- return ChoiceField(**kwargs)
-
- # Ensure that max_length is passed explicitly as a keyword arg,
- # rather than as a validator.
- max_length = getattr(model_field, 'max_length', None)
- if max_length is not None:
- kwargs['max_length'] = max_length
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.MaxLengthValidator)
- ]
-
- # Ensure that min_length is passed explicitly as a keyword arg,
- # rather than as a validator.
- min_length = getattr(model_field, 'min_length', None)
- if min_length is not None:
- kwargs['min_length'] = min_length
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.MinLengthValidator)
- ]
-
- # Ensure that max_value is passed explicitly as a keyword arg,
- # rather than as a validator.
- max_value = next((
- validator.limit_value for validator in validator_kwarg
- if isinstance(validator, validators.MaxValueValidator)
- ), None)
- if max_value is not None:
- kwargs['max_value'] = max_value
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.MaxValueValidator)
- ]
-
- # Ensure that max_value is passed explicitly as a keyword arg,
- # rather than as a validator.
- min_value = next((
- validator.limit_value for validator in validator_kwarg
- if isinstance(validator, validators.MinValueValidator)
- ), None)
- if min_value is not None:
- kwargs['min_value'] = min_value
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.MinValueValidator)
- ]
-
- # URLField does not need to include the URLValidator argument,
- # as it is explicitly added in.
- if isinstance(model_field, models.URLField):
- validator_kwarg = [
- validator for validator in validator_kwarg
- if not isinstance(validator, validators.URLValidator)
- ]
-
- # EmailField does not need to include the validate_email argument,
- # as it is explicitly added in.
- if isinstance(model_field, models.EmailField):
- validator_kwarg = [
- validator for validator in validator_kwarg
- if validator is not validators.validate_email
- ]
-
- # SlugField do not need to include the 'validate_slug' argument,
- if isinstance(model_field, models.SlugField):
- validator_kwarg = [
- validator for validator in validator_kwarg
- if validator is not validators.validate_slug
- ]
-
- max_digits = getattr(model_field, 'max_digits', None)
- if max_digits is not None:
- kwargs['max_digits'] = max_digits
-
- decimal_places = getattr(model_field, 'decimal_places', None)
- if decimal_places is not None:
- kwargs['decimal_places'] = decimal_places
-
- if isinstance(model_field, models.BooleanField):
- # models.BooleanField has `blank=True`, but *is* actually
- # required *unless* a default is provided.
- # Also note that Django<1.6 uses `default=False` for
- # models.BooleanField, but Django>=1.6 uses `default=None`.
- kwargs.pop('required', None)
-
- if validator_kwarg:
- kwargs['validators'] = validator_kwarg
-
- if issubclass(serializer_cls, ModelField):
- kwargs['model_field'] = model_field
-
- return serializer_cls(**kwargs)
-
-
-class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
- """
- Options for HyperlinkedModelSerializer
- """
- def __init__(self, meta):
- super(HyperlinkedModelSerializerOptions, self).__init__(meta)
- self.view_name = getattr(meta, 'view_name', None)
- self.lookup_field = getattr(meta, 'lookup_field', None)
+ model = relation_info.related
+ depth = nested_depth
+ return NestedSerializer
class HyperlinkedModelSerializer(ModelSerializer):
- _options_class = HyperlinkedModelSerializerOptions
-
- def get_url_field(self):
- if self.opts.view_name is not None:
- view_name = self.opts.view_name
- else:
- view_name = self.get_default_view_name(self.opts.model)
-
- kwargs = {
- 'view_name': view_name
- }
- if self.opts.lookup_field:
- kwargs['lookup_field'] = self.opts.lookup_field
-
- return HyperlinkedIdentityField(**kwargs)
-
- def get_pk_field(self, field_name, model_field):
- if self.opts.fields and model_field.name in self.opts.fields:
- return self.get_field(model_field)
-
- def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model):
- """
- Creates a default instance of a flat relational field.
- """
- kwargs = {
- 'queryset': related_model._default_manager,
- 'view_name': self.get_default_view_name(related_model),
- }
-
- if to_many:
- kwargs['many'] = True
-
- if has_through_model:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
-
- if model_field:
- if model_field.null or model_field.blank:
- kwargs['required'] = False
- if model_field.verbose_name and needs_label(model_field, field_name):
- kwargs['label'] = capfirst(model_field.verbose_name)
- if not model_field.editable:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
- help_text = clean_manytomany_helptext(model_field.help_text)
- if help_text:
- kwargs['help_text'] = help_text
-
- return HyperlinkedRelatedField(**kwargs)
-
- def get_default_view_name(self, model):
- """
- Return the view name to use for related models.
- """
- return '%(model_name)s-detail' % {
- 'app_label': model._meta.app_label,
- 'model_name': model._meta.object_name.lower()
- }
-
-
-ModelSerializer.nested_class = ModelSerializer
-HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer
+ _related_class = HyperlinkedRelatedField
+
+ def _get_default_field_names(self, declared_fields, model_info):
+ return (
+ [api_settings.URL_FIELD_NAME] +
+ list(declared_fields.keys()) +
+ list(model_info.fields.keys()) +
+ list(model_info.forward_relations.keys())
+ )
+
+ def _get_nested_class(self, nested_depth, relation_info):
+ class NestedSerializer(HyperlinkedModelSerializer):
+ class Meta:
+ model = relation_info.related
+ depth = nested_depth
+ return NestedSerializer
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
new file mode 100644
index 00000000..be72e444
--- /dev/null
+++ b/rest_framework/utils/field_mapping.py
@@ -0,0 +1,215 @@
+"""
+Helper functions for mapping model fields to a dictionary of default
+keyword arguments that should be used for their equivelent serializer fields.
+"""
+from django.core import validators
+from django.db import models
+from django.utils.text import capfirst
+from rest_framework.compat import clean_manytomany_helptext
+import inspect
+
+
+def lookup_class(mapping, instance):
+ """
+ Takes a dictionary with classes as keys, and an object.
+ Traverses the object's inheritance hierarchy in method
+ resolution order, and returns the first matching value
+ from the dictionary or raises a KeyError if nothing matches.
+ """
+ for cls in inspect.getmro(instance.__class__):
+ if cls in mapping:
+ return mapping[cls]
+ raise KeyError('Class %s not found in lookup.', cls.__name__)
+
+
+def needs_label(model_field, field_name):
+ """
+ Returns `True` if the label based on the model's verbose name
+ is not equal to the default label it would have based on it's field name.
+ """
+ default_label = field_name.replace('_', ' ').capitalize()
+ return capfirst(model_field.verbose_name) != default_label
+
+
+def get_detail_view_name(model):
+ """
+ Given a model class, return the view name to use for URL relationships
+ that refer to instances of the model.
+ """
+ return '%(model_name)s-detail' % {
+ 'app_label': model._meta.app_label,
+ 'model_name': model._meta.object_name.lower()
+ }
+
+
+def get_field_kwargs(field_name, model_field):
+ """
+ Creates a default instance of a basic non-relational field.
+ """
+ kwargs = {}
+ validator_kwarg = model_field.validators
+
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
+
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
+
+ if model_field.help_text:
+ kwargs['help_text'] = model_field.help_text
+
+ if isinstance(model_field, models.AutoField) or not model_field.editable:
+ kwargs['read_only'] = True
+ # Read only implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
+
+ if model_field.has_default():
+ kwargs['default'] = model_field.get_default()
+ # Having a default implies that the field is not required.
+ # We have a cleaner repr on the instance if we don't set it.
+ kwargs.pop('required', None)
+
+ if model_field.flatchoices:
+ # If this model field contains choices, then return now,
+ # any further keyword arguments are not valid.
+ kwargs['choices'] = model_field.flatchoices
+ return kwargs
+
+ # Ensure that max_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_length = getattr(model_field, 'max_length', None)
+ if max_length is not None:
+ kwargs['max_length'] = max_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxLengthValidator)
+ ]
+
+ # Ensure that min_length is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_length = getattr(model_field, 'min_length', None)
+ if min_length is not None:
+ kwargs['min_length'] = min_length
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinLengthValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ max_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MaxValueValidator)
+ ), None)
+ if max_value is not None:
+ kwargs['max_value'] = max_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MaxValueValidator)
+ ]
+
+ # Ensure that max_value is passed explicitly as a keyword arg,
+ # rather than as a validator.
+ min_value = next((
+ validator.limit_value for validator in validator_kwarg
+ if isinstance(validator, validators.MinValueValidator)
+ ), None)
+ if min_value is not None:
+ kwargs['min_value'] = min_value
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.MinValueValidator)
+ ]
+
+ # URLField does not need to include the URLValidator argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.URLField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if not isinstance(validator, validators.URLValidator)
+ ]
+
+ # EmailField does not need to include the validate_email argument,
+ # as it is explicitly added in.
+ if isinstance(model_field, models.EmailField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_email
+ ]
+
+ # SlugField do not need to include the 'validate_slug' argument,
+ if isinstance(model_field, models.SlugField):
+ validator_kwarg = [
+ validator for validator in validator_kwarg
+ if validator is not validators.validate_slug
+ ]
+
+ max_digits = getattr(model_field, 'max_digits', None)
+ if max_digits is not None:
+ kwargs['max_digits'] = max_digits
+
+ decimal_places = getattr(model_field, 'decimal_places', None)
+ if decimal_places is not None:
+ kwargs['decimal_places'] = decimal_places
+
+ if isinstance(model_field, models.BooleanField):
+ # models.BooleanField has `blank=True`, but *is* actually
+ # required *unless* a default is provided.
+ # Also note that Django<1.6 uses `default=False` for
+ # models.BooleanField, but Django>=1.6 uses `default=None`.
+ kwargs.pop('required', None)
+
+ if validator_kwarg:
+ kwargs['validators'] = validator_kwarg
+
+ # The following will only be used by ModelField classes.
+ # Gets removed for everything else.
+ kwargs['model_field'] = model_field
+
+ return kwargs
+
+
+def get_relation_kwargs(field_name, relation_info):
+ """
+ Creates a default instance of a flat relational field.
+ """
+ model_field, related_model, to_many, has_through_model = relation_info
+ kwargs = {
+ 'queryset': related_model._default_manager,
+ 'view_name': get_detail_view_name(related_model)
+ }
+
+ if to_many:
+ kwargs['many'] = True
+
+ if has_through_model:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+
+ if model_field:
+ if model_field.null or model_field.blank:
+ kwargs['required'] = False
+ if model_field.verbose_name and needs_label(model_field, field_name):
+ kwargs['label'] = capfirst(model_field.verbose_name)
+ if not model_field.editable:
+ kwargs['read_only'] = True
+ kwargs.pop('queryset', None)
+ help_text = clean_manytomany_helptext(model_field.help_text)
+ if help_text:
+ kwargs['help_text'] = help_text
+
+ return kwargs
+
+
+def get_nested_relation_kwargs(relation_info):
+ kwargs = {'read_only': True}
+ if relation_info.to_many:
+ kwargs['many'] = True
+ return kwargs
+
+
+def get_url_kwargs(model_field):
+ return {
+ 'view_name': get_detail_view_name(model_field)
+ }
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
index 960fa4d0..b6c41174 100644
--- a/rest_framework/utils/model_meta.py
+++ b/rest_framework/utils/model_meta.py
@@ -1,7 +1,9 @@
"""
-Helper functions for returning the field information that is associated
+Helper function for returning the field information that is associated
with a model class. This includes returning all the forward and reverse
relationships and their associated metadata.
+
+Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import namedtuple
from django.db import models
@@ -9,8 +11,22 @@ from django.utils import six
from django.utils.datastructures import SortedDict
import inspect
-FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
-RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
+
+FieldInfo = namedtuple('FieldResult', [
+ 'pk', # Model field instance
+ 'fields', # Dict of field name -> model field instance
+ 'forward_relations', # Dict of field name -> RelationInfo
+ 'reverse_relations', # Dict of field name -> RelationInfo
+ 'fields_and_pk', # Shortcut for 'pk' + 'fields'
+ 'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
+])
+
+RelationInfo = namedtuple('RelationInfo', [
+ 'model_field',
+ 'related',
+ 'to_many',
+ 'has_through_model'
+])
def _resolve_model(obj):
@@ -55,7 +71,7 @@ def get_field_info(model):
forward_relations = SortedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
- field=field,
+ model_field=field,
related=_resolve_model(field.rel.to),
to_many=False,
has_through_model=False
@@ -64,7 +80,7 @@ def get_field_info(model):
# Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
- field=field,
+ model_field=field,
related=_resolve_model(field.rel.to),
to_many=True,
has_through_model=(
@@ -77,7 +93,7 @@ def get_field_info(model):
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
- field=None,
+ model_field=None,
related=relation.model,
to_many=relation.field.rel.multiple,
has_through_model=False
@@ -87,7 +103,7 @@ def get_field_info(model):
for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
- field=None,
+ model_field=None,
related=relation.model,
to_many=True,
has_through_model=(
@@ -96,4 +112,18 @@ def get_field_info(model):
)
)
- return FieldInfo(pk, fields, forward_relations, reverse_relations)
+ # Shortcut that merges both regular fields and the pk,
+ # for simplifying regular field lookup.
+ fields_and_pk = SortedDict()
+ fields_and_pk['pk'] = pk
+ fields_and_pk[pk.name] = pk
+ fields_and_pk.update(fields)
+
+ # Shortcut that merges both forward and reverse relationships
+
+ relations = SortedDict(
+ list(forward_relations.items()) +
+ list(reverse_relations.items())
+ )
+
+ return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations)
diff --git a/tests/models.py b/tests/models.py
index fe064b46..06ec5a22 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -1,7 +1,6 @@
from __future__ import unicode_literals
from django.db import models
from django.utils.translation import ugettext_lazy as _
-from rest_framework import serializers
def foobar():
@@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
related_name='nullable_source')
-
-
-# Serializer used to test BasicModel
-class BasicModelSerializer(serializers.ModelSerializer):
- class Meta:
- model = BasicModel
diff --git a/tests/test_model_field_mappings.py b/tests/test_model_field_mappings.py
index bae63e5a..6daa574e 100644
--- a/tests/test_model_field_mappings.py
+++ b/tests/test_model_field_mappings.py
@@ -126,16 +126,16 @@ class TestRelationalFieldMappings(TestCase):
expected = dedent("""
TestSerializer():
id = IntegerField(label='ID', read_only=True)
- foreign_key = NestedModelSerializer(read_only=True):
+ foreign_key = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
- one_to_one = NestedModelSerializer(read_only=True):
+ one_to_one = NestedSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
- many_to_many = NestedModelSerializer(many=True, read_only=True):
+ many_to_many = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
- through = NestedModelSerializer(many=True, read_only=True):
+ through = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
""")
@@ -165,16 +165,16 @@ class TestRelationalFieldMappings(TestCase):
expected = dedent("""
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
- foreign_key = NestedModelSerializer(read_only=True):
+ foreign_key = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')
name = CharField(max_length=100)
- one_to_one = NestedModelSerializer(read_only=True):
+ one_to_one = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100)
- many_to_many = NestedModelSerializer(many=True, read_only=True):
+ many_to_many = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
name = CharField(max_length=100)
- through = NestedModelSerializer(many=True, read_only=True):
+ through = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100)
""")
diff --git a/tests/test_response.py b/tests/test_response.py
index 67419a71..84c39c1a 100644
--- a/tests/test_response.py
+++ b/tests/test_response.py
@@ -2,11 +2,12 @@ from __future__ import unicode_literals
from django.conf.urls import patterns, url, include
from django.test import TestCase
from django.utils import six
-from tests.models import BasicModel, BasicModelSerializer
+from tests.models import BasicModel
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework import generics
from rest_framework import routers
+from rest_framework import serializers
from rest_framework import status
from rest_framework.renderers import (
BaseRenderer,
@@ -17,6 +18,12 @@ from rest_framework import viewsets
from rest_framework.settings import api_settings
+# Serializer used to test BasicModel
+class BasicModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+
class MockPickleRenderer(BaseRenderer):
media_type = 'application/pickle'
diff --git a/tests/test_routers.py b/tests/test_routers.py
index b076f134..c2d595f7 100644
--- a/tests/test_routers.py
+++ b/tests/test_routers.py
@@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase):
def setUp(self):
class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
+
class Meta:
model = RouterTestModel
- lookup_field = 'uuid'
fields = ('url', 'uuid', 'text')
class NoteViewSet(viewsets.ModelViewSet):
@@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase):
serializer_class = NoteSerializer
lookup_field = 'uuid'
- RouterTestModel.objects.create(uuid='123', text='foo bar')
-
self.router = SimpleRouter()
self.router.register(r'notes', NoteViewSet)
@@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase):
url(r'^', include(self.router.urls)),
)
+ RouterTestModel.objects.create(uuid='123', text='foo bar')
+
def test_custom_lookup_field_route(self):
detail_route = self.router.urls[-1]
detail_url_pattern = detail_route.regex.pattern