diff options
| -rw-r--r-- | rest_framework/fields.py | 6 | ||||
| -rw-r--r-- | rest_framework/relations.py | 9 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 464 | ||||
| -rw-r--r-- | rest_framework/utils/field_mapping.py | 215 | ||||
| -rw-r--r-- | rest_framework/utils/model_meta.py | 46 | ||||
| -rw-r--r-- | tests/models.py | 7 | ||||
| -rw-r--r-- | tests/test_model_field_mappings.py | 16 | ||||
| -rw-r--r-- | tests/test_response.py | 9 | ||||
| -rw-r--r-- | tests/test_routers.py | 7 | 
9 files changed, 384 insertions, 395 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1818e705..0c78b3fb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -80,10 +80,6 @@ def set_value(dictionary, keys, value):      dictionary[keys[-1]] = value -def field_name_to_label(field_name): -    return field_name.replace('_', ' ').capitalize() - -  class SkipField(Exception):      pass @@ -162,7 +158,7 @@ class Field(object):          # `self.label` should deafult to being based on the field name.          if self.label is None: -            self.label = field_name_to_label(self.field_name) +            self.label = field_name.replace('_', ' ').capitalize()          # self.source should default to being the same as the field name.          if self.source is None: diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 46fe55ef..9f44ab63 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField):          'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.',      } -    def __init__(self, view_name, **kwargs): +    def __init__(self, view_name=None, **kwargs): +        assert view_name is not None, 'The `view_name` argument is required.'          self.view_name = view_name          self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)          self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) @@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):      URL of relationships to other objects.      """ -    def __init__(self, view_name, **kwargs): +    def __init__(self, view_name=None, **kwargs): +        assert view_name is not None, 'The `view_name` argument is required.'          kwargs['read_only'] = True          kwargs['source'] = '*'          super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) @@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField):          'invalid': _('Invalid value.'),      } -    def __init__(self, slug_field, **kwargs): +    def __init__(self, slug_field=None, **kwargs): +        assert slug_field is not None, 'The `slug_field` argument is required.'          self.slug_field = slug_field          super(SlugRelatedField, self).__init__(**kwargs) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 1fea1380..99dcc349 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,17 +10,19 @@ python primitives.  2. The process of marshalling between python primitives and request and  response content is handled by parsers and renderers.  """ -from django.core import validators  from django.core.exceptions import ValidationError  from django.db import models  from django.utils import six  from django.utils.datastructures import SortedDict -from django.utils.text import capfirst  from collections import namedtuple -from rest_framework.compat import clean_manytomany_helptext  from rest_framework.fields import empty, set_value, Field, SkipField  from rest_framework.settings import api_settings  from rest_framework.utils import html, model_meta, representation +from rest_framework.utils.field_mapping import ( +    get_url_kwargs, get_field_kwargs, +    get_relation_kwargs, get_nested_relation_kwargs, +    lookup_class +)  import copy  # Note: We do the following so that users of the framework can use this style: @@ -126,7 +128,7 @@ class SerializerMetaclass(type):      """      @classmethod -    def _get_fields(cls, bases, attrs): +    def _get_declared_fields(cls, bases, attrs):          fields = [(field_name, attrs.pop(field_name))                    for field_name, obj in list(attrs.items())                    if isinstance(obj, Field)] @@ -136,25 +138,18 @@ class SerializerMetaclass(type):          # fields.  Note that we loop over the bases in *reverse*. This is necessary          # in order to maintain the correct order of fields.          for base in bases[::-1]: -            if hasattr(base, 'base_fields'): -                fields = list(base.base_fields.items()) + fields +            if hasattr(base, '_declared_fields'): +                fields = list(base._declared_fields.items()) + fields          return SortedDict(fields)      def __new__(cls, name, bases, attrs): -        attrs['base_fields'] = cls._get_fields(bases, attrs) +        attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)          return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)  @six.add_metaclass(SerializerMetaclass)  class Serializer(BaseSerializer): - -    def __new__(cls, *args, **kwargs): -        if kwargs.pop('many', False): -            kwargs['child'] = cls() -            return ListSerializer(*args, **kwargs) -        return super(Serializer, cls).__new__(cls, *args, **kwargs) -      def __init__(self, *args, **kwargs):          self.context = kwargs.pop('context', {})          kwargs.pop('partial', None) @@ -165,14 +160,22 @@ class Serializer(BaseSerializer):          # Every new serializer is created with a clone of the field instances.          # This allows users to dynamically modify the fields on a serializer          # instance without affecting every other serializer class. -        self.fields = self.get_fields() +        self.fields = self._get_base_fields()          # Setup all the child fields, to provide them with the current context.          for field_name, field in self.fields.items():              field.bind(field_name, self, self) -    def get_fields(self): -        return copy.deepcopy(self.base_fields) +    def __new__(cls, *args, **kwargs): +        # We override this method in order to automagically create +        # `ListSerializer` classes instead when `many=True` is set. +        if kwargs.pop('many', False): +            kwargs['child'] = cls() +            return ListSerializer(*args, **kwargs) +        return super(Serializer, cls).__new__(cls, *args, **kwargs) + +    def _get_base_fields(self): +        return copy.deepcopy(self._declared_fields)      def bind(self, field_name, parent, root):          # If the serializer is used as a field then when it becomes bound @@ -312,39 +315,8 @@ class ListSerializer(BaseSerializer):          return representation.list_repr(self, indent=1) -class ModelSerializerOptions(object): -    """ -    Meta class options for ModelSerializer -    """ -    def __init__(self, meta): -        self.model = getattr(meta, 'model') -        self.fields = getattr(meta, 'fields', ()) -        self.depth = getattr(meta, 'depth', 0) - - -def lookup_class(mapping, instance): -    """ -    Takes a dictionary with classes as keys, and an object. -    Traverses the object's inheritance hierarchy in method -    resolution order, and returns the first matching value -    from the dictionary or raises a KeyError if nothing matches. -    """ -    for cls in inspect.getmro(instance.__class__): -        if cls in mapping: -            return mapping[cls] -    raise KeyError('Class %s not found in lookup.', cls.__name__) - - -def needs_label(model_field, field_name): -    """ -    Returns `True` if the label based on the model's verbose name -    is not equal to the default label it would have based on it's field name. -    """ -    return capfirst(model_field.verbose_name) != field_name_to_label(field_name) - -  class ModelSerializer(Serializer): -    field_mapping = { +    _field_mapping = {          models.AutoField: IntegerField,          models.BigIntegerField: IntegerField,          models.BooleanField: BooleanField, @@ -368,16 +340,10 @@ class ModelSerializer(Serializer):          models.TimeField: TimeField,          models.URLField: URLField,      } -    nested_class = None  # We fill this in at the end of this module. - -    _options_class = ModelSerializerOptions - -    def __init__(self, *args, **kwargs): -        self.opts = self._options_class(self.Meta) -        super(ModelSerializer, self).__init__(*args, **kwargs) +    _related_class = PrimaryKeyRelatedField      def create(self, attrs): -        ModelClass = self.opts.model +        ModelClass = self.Meta.model          return ModelClass.objects.create(**attrs)      def update(self, obj, attrs): @@ -385,319 +351,97 @@ class ModelSerializer(Serializer):              setattr(obj, attr, value)          obj.save() -    def get_fields(self): -        # Get the explicitly declared fields. -        fields = copy.deepcopy(self.base_fields) +    def _get_base_fields(self): +        declared_fields = copy.deepcopy(self._declared_fields) -        # Add in the default fields. -        for key, val in self.get_default_fields().items(): -            if key not in fields: -                fields[key] = val - -        # If `fields` is set on the `Meta` class, -        # then use only those fields, and in that order. -        if self.opts.fields: -            fields = SortedDict([ -                (key, fields[key]) for key in self.opts.fields -            ]) - -        return fields - -    def get_default_fields(self): -        """ -        Return all the fields that should be serialized for the model. -        """ -        info = model_meta.get_field_info(self.opts.model)          ret = SortedDict() +        model = getattr(self.Meta, 'model') +        fields = getattr(self.Meta, 'fields', None) +        depth = getattr(self.Meta, 'depth', 0) + +        # Retrieve metadata about fields & relationships on the model class. +        info = model_meta.get_field_info(model) + +        # Use the default set of fields if none is supplied explicitly. +        if fields is None: +            fields = self._get_default_field_names(declared_fields, info) + +        for field_name in fields: +            if field_name in declared_fields: +                # Field is explicitly declared on the class, use that. +                ret[field_name] = declared_fields[field_name] +                continue + +            elif field_name == api_settings.URL_FIELD_NAME: +                # Create the URL field. +                field_cls = HyperlinkedIdentityField +                kwargs = get_url_kwargs(model) + +            elif field_name in info.fields_and_pk: +                # Create regular model fields. +                model_field = info.fields_and_pk[field_name] +                field_cls = lookup_class(self._field_mapping, model_field) +                kwargs = get_field_kwargs(field_name, model_field) +                if 'choices' in kwargs: +                    # Fields with choices get coerced into `ChoiceField` +                    # instead of using their regular typed field. +                    field_cls = ChoiceField +                if not issubclass(field_cls, ModelField): +                    # `model_field` is only valid for the fallback case of +                    # `ModelField`, which is used when no other typed field +                    # matched to the model field. +                    kwargs.pop('model_field', None) + +            elif field_name in info.relations: +                # Create forward and reverse relationships. +                relation_info = info.relations[field_name] +                if depth: +                    field_cls = self._get_nested_class(depth, relation_info) +                    kwargs = get_nested_relation_kwargs(relation_info) +                else: +                    field_cls = self._related_class +                    kwargs = get_relation_kwargs(field_name, relation_info) +                    # `view_name` is only valid for hyperlinked relationships. +                    if not issubclass(field_cls, HyperlinkedRelatedField): +                        kwargs.pop('view_name', None) -        # URL field -        serializer_url_field = self.get_url_field() -        if serializer_url_field: -            ret[api_settings.URL_FIELD_NAME] = serializer_url_field - -        # Primary key field -        field_name = info.pk.name -        serializer_pk_field = self.get_pk_field(field_name, info.pk) -        if serializer_pk_field: -            ret[field_name] = serializer_pk_field - -        # Regular fields -        for field_name, field in info.fields.items(): -            ret[field_name] = self.get_field(field_name, field) - -        # Forward relations -        for field_name, relation_info in info.forward_relations.items(): -            if self.opts.depth: -                ret[field_name] = self.get_nested_field(field_name, *relation_info)              else: -                ret[field_name] = self.get_related_field(field_name, *relation_info) +                assert False, 'Field name `%s` is not valid.' % field_name -        # Reverse relations -        for accessor_name, relation_info in info.reverse_relations.items(): -            if accessor_name in self.opts.fields: -                if self.opts.depth: -                    ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info) -                else: -                    ret[accessor_name] = self.get_related_field(accessor_name, *relation_info) +            ret[field_name] = field_cls(**kwargs)          return ret -    def get_url_field(self): -        return None - -    def get_pk_field(self, field_name, model_field): -        """ -        Returns a default instance of the pk field. -        """ -        return self.get_field(field_name, model_field) - -    def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model): -        """ -        Creates a default instance of a nested relational field. +    def _get_default_field_names(self, declared_fields, model_info): +        return ( +            [model_info.pk.name] + +            list(declared_fields.keys()) + +            list(model_info.fields.keys()) + +            list(model_info.forward_relations.keys()) +        ) -        Note that model_field will be `None` for reverse relationships. -        """ -        class NestedModelSerializer(self.nested_class): +    def _get_nested_class(self, nested_depth, relation_info): +        class NestedSerializer(ModelSerializer):              class Meta: -                model = related_model -                depth = self.opts.depth - 1 - -        kwargs = {'read_only': True} -        if to_many: -            kwargs['many'] = True -        return NestedModelSerializer(**kwargs) - -    def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): -        """ -        Creates a default instance of a flat relational field. - -        Note that model_field will be `None` for reverse relationships. -        """ -        kwargs = { -            'queryset': related_model._default_manager, -        } - -        if to_many: -            kwargs['many'] = True - -        if has_through_model: -            kwargs['read_only'] = True -            kwargs.pop('queryset', None) - -        if model_field: -            if model_field.null or model_field.blank: -                kwargs['required'] = False -            if model_field.verbose_name and needs_label(model_field, field_name): -                kwargs['label'] = capfirst(model_field.verbose_name) -            if not model_field.editable: -                kwargs['read_only'] = True -                kwargs.pop('queryset', None) -            help_text = clean_manytomany_helptext(model_field.help_text) -            if help_text: -                kwargs['help_text'] = help_text - -        return PrimaryKeyRelatedField(**kwargs) - -    def get_field(self, field_name, model_field): -        """ -        Creates a default instance of a basic non-relational field. -        """ -        serializer_cls = lookup_class(self.field_mapping, model_field) -        kwargs = {} -        validator_kwarg = model_field.validators - -        if model_field.null or model_field.blank: -            kwargs['required'] = False - -        if model_field.verbose_name and needs_label(model_field, field_name): -            kwargs['label'] = capfirst(model_field.verbose_name) - -        if model_field.help_text: -            kwargs['help_text'] = model_field.help_text - -        if isinstance(model_field, models.AutoField) or not model_field.editable: -            kwargs['read_only'] = True -            # Read only implies that the field is not required. -            # We have a cleaner repr on the instance if we don't set it. -            kwargs.pop('required', None) - -        if model_field.has_default(): -            kwargs['default'] = model_field.get_default() -            # Having a default implies that the field is not required. -            # We have a cleaner repr on the instance if we don't set it. -            kwargs.pop('required', None) - -        if model_field.flatchoices: -            # If this model field contains choices, then use a ChoiceField, -            # rather than the standard serializer field for this type. -            # Note that we return this prior to setting any validation type -            # keyword arguments, as those are not valid initializers. -            kwargs['choices'] = model_field.flatchoices -            return ChoiceField(**kwargs) - -        # Ensure that max_length is passed explicitly as a keyword arg, -        # rather than as a validator. -        max_length = getattr(model_field, 'max_length', None) -        if max_length is not None: -            kwargs['max_length'] = max_length -            validator_kwarg = [ -                validator for validator in validator_kwarg -                if not isinstance(validator, validators.MaxLengthValidator) -            ] - -        # Ensure that min_length is passed explicitly as a keyword arg, -        # rather than as a validator. -        min_length = getattr(model_field, 'min_length', None) -        if min_length is not None: -            kwargs['min_length'] = min_length -            validator_kwarg = [ -                validator for validator in validator_kwarg -                if not isinstance(validator, validators.MinLengthValidator) -            ] - -        # Ensure that max_value is passed explicitly as a keyword arg, -        # rather than as a validator. -        max_value = next(( -            validator.limit_value for validator in validator_kwarg -            if isinstance(validator, validators.MaxValueValidator) -        ), None) -        if max_value is not None: -            kwargs['max_value'] = max_value -            validator_kwarg = [ -                validator for validator in validator_kwarg -                if not isinstance(validator, validators.MaxValueValidator) -            ] - -        # Ensure that max_value is passed explicitly as a keyword arg, -        # rather than as a validator. -        min_value = next(( -            validator.limit_value for validator in validator_kwarg -            if isinstance(validator, validators.MinValueValidator) -        ), None) -        if min_value is not None: -            kwargs['min_value'] = min_value -            validator_kwarg = [ -                validator for validator in validator_kwarg -                if not isinstance(validator, validators.MinValueValidator) -            ] - -        # URLField does not need to include the URLValidator argument, -        # as it is explicitly added in. -        if isinstance(model_field, models.URLField): -            validator_kwarg = [ -                validator for validator in validator_kwarg -                if not isinstance(validator, validators.URLValidator) -            ] - -        # EmailField does not need to include the validate_email argument, -        # as it is explicitly added in. -        if isinstance(model_field, models.EmailField): -            validator_kwarg = [ -                validator for validator in validator_kwarg -                if validator is not validators.validate_email -            ] - -        # SlugField do not need to include the 'validate_slug' argument, -        if isinstance(model_field, models.SlugField): -            validator_kwarg = [ -                validator for validator in validator_kwarg -                if validator is not validators.validate_slug -            ] - -        max_digits = getattr(model_field, 'max_digits', None) -        if max_digits is not None: -            kwargs['max_digits'] = max_digits - -        decimal_places = getattr(model_field, 'decimal_places', None) -        if decimal_places is not None: -            kwargs['decimal_places'] = decimal_places - -        if isinstance(model_field, models.BooleanField): -            # models.BooleanField has `blank=True`, but *is* actually -            # required *unless* a default is provided. -            # Also note that Django<1.6 uses `default=False` for -            # models.BooleanField, but Django>=1.6 uses `default=None`. -            kwargs.pop('required', None) - -        if validator_kwarg: -            kwargs['validators'] = validator_kwarg - -        if issubclass(serializer_cls, ModelField): -            kwargs['model_field'] = model_field - -        return serializer_cls(**kwargs) - - -class HyperlinkedModelSerializerOptions(ModelSerializerOptions): -    """ -    Options for HyperlinkedModelSerializer -    """ -    def __init__(self, meta): -        super(HyperlinkedModelSerializerOptions, self).__init__(meta) -        self.view_name = getattr(meta, 'view_name', None) -        self.lookup_field = getattr(meta, 'lookup_field', None) +                model = relation_info.related +                depth = nested_depth +        return NestedSerializer  class HyperlinkedModelSerializer(ModelSerializer): -    _options_class = HyperlinkedModelSerializerOptions - -    def get_url_field(self): -        if self.opts.view_name is not None: -            view_name = self.opts.view_name -        else: -            view_name = self.get_default_view_name(self.opts.model) - -        kwargs = { -            'view_name': view_name -        } -        if self.opts.lookup_field: -            kwargs['lookup_field'] = self.opts.lookup_field - -        return HyperlinkedIdentityField(**kwargs) - -    def get_pk_field(self, field_name, model_field): -        if self.opts.fields and model_field.name in self.opts.fields: -            return self.get_field(model_field) - -    def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): -        """ -        Creates a default instance of a flat relational field. -        """ -        kwargs = { -            'queryset': related_model._default_manager, -            'view_name': self.get_default_view_name(related_model), -        } - -        if to_many: -            kwargs['many'] = True - -        if has_through_model: -            kwargs['read_only'] = True -            kwargs.pop('queryset', None) - -        if model_field: -            if model_field.null or model_field.blank: -                kwargs['required'] = False -            if model_field.verbose_name and needs_label(model_field, field_name): -                kwargs['label'] = capfirst(model_field.verbose_name) -            if not model_field.editable: -                kwargs['read_only'] = True -                kwargs.pop('queryset', None) -            help_text = clean_manytomany_helptext(model_field.help_text) -            if help_text: -                kwargs['help_text'] = help_text - -        return HyperlinkedRelatedField(**kwargs) - -    def get_default_view_name(self, model): -        """ -        Return the view name to use for related models. -        """ -        return '%(model_name)s-detail' % { -            'app_label': model._meta.app_label, -            'model_name': model._meta.object_name.lower() -        } - - -ModelSerializer.nested_class = ModelSerializer -HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer +    _related_class = HyperlinkedRelatedField + +    def _get_default_field_names(self, declared_fields, model_info): +        return ( +            [api_settings.URL_FIELD_NAME] + +            list(declared_fields.keys()) + +            list(model_info.fields.keys()) + +            list(model_info.forward_relations.keys()) +        ) + +    def _get_nested_class(self, nested_depth, relation_info): +        class NestedSerializer(HyperlinkedModelSerializer): +            class Meta: +                model = relation_info.related +                depth = nested_depth +        return NestedSerializer diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py new file mode 100644 index 00000000..be72e444 --- /dev/null +++ b/rest_framework/utils/field_mapping.py @@ -0,0 +1,215 @@ +""" +Helper functions for mapping model fields to a dictionary of default +keyword arguments that should be used for their equivelent serializer fields. +""" +from django.core import validators +from django.db import models +from django.utils.text import capfirst +from rest_framework.compat import clean_manytomany_helptext +import inspect + + +def lookup_class(mapping, instance): +    """ +    Takes a dictionary with classes as keys, and an object. +    Traverses the object's inheritance hierarchy in method +    resolution order, and returns the first matching value +    from the dictionary or raises a KeyError if nothing matches. +    """ +    for cls in inspect.getmro(instance.__class__): +        if cls in mapping: +            return mapping[cls] +    raise KeyError('Class %s not found in lookup.', cls.__name__) + + +def needs_label(model_field, field_name): +    """ +    Returns `True` if the label based on the model's verbose name +    is not equal to the default label it would have based on it's field name. +    """ +    default_label = field_name.replace('_', ' ').capitalize() +    return capfirst(model_field.verbose_name) != default_label + + +def get_detail_view_name(model): +    """ +    Given a model class, return the view name to use for URL relationships +    that refer to instances of the model. +    """ +    return '%(model_name)s-detail' % { +        'app_label': model._meta.app_label, +        'model_name': model._meta.object_name.lower() +    } + + +def get_field_kwargs(field_name, model_field): +    """ +    Creates a default instance of a basic non-relational field. +    """ +    kwargs = {} +    validator_kwarg = model_field.validators + +    if model_field.null or model_field.blank: +        kwargs['required'] = False + +    if model_field.verbose_name and needs_label(model_field, field_name): +        kwargs['label'] = capfirst(model_field.verbose_name) + +    if model_field.help_text: +        kwargs['help_text'] = model_field.help_text + +    if isinstance(model_field, models.AutoField) or not model_field.editable: +        kwargs['read_only'] = True +        # Read only implies that the field is not required. +        # We have a cleaner repr on the instance if we don't set it. +        kwargs.pop('required', None) + +    if model_field.has_default(): +        kwargs['default'] = model_field.get_default() +        # Having a default implies that the field is not required. +        # We have a cleaner repr on the instance if we don't set it. +        kwargs.pop('required', None) + +    if model_field.flatchoices: +        # If this model field contains choices, then return now, +        # any further keyword arguments are not valid. +        kwargs['choices'] = model_field.flatchoices +        return kwargs + +    # Ensure that max_length is passed explicitly as a keyword arg, +    # rather than as a validator. +    max_length = getattr(model_field, 'max_length', None) +    if max_length is not None: +        kwargs['max_length'] = max_length +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.MaxLengthValidator) +        ] + +    # Ensure that min_length is passed explicitly as a keyword arg, +    # rather than as a validator. +    min_length = getattr(model_field, 'min_length', None) +    if min_length is not None: +        kwargs['min_length'] = min_length +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.MinLengthValidator) +        ] + +    # Ensure that max_value is passed explicitly as a keyword arg, +    # rather than as a validator. +    max_value = next(( +        validator.limit_value for validator in validator_kwarg +        if isinstance(validator, validators.MaxValueValidator) +    ), None) +    if max_value is not None: +        kwargs['max_value'] = max_value +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.MaxValueValidator) +        ] + +    # Ensure that max_value is passed explicitly as a keyword arg, +    # rather than as a validator. +    min_value = next(( +        validator.limit_value for validator in validator_kwarg +        if isinstance(validator, validators.MinValueValidator) +    ), None) +    if min_value is not None: +        kwargs['min_value'] = min_value +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.MinValueValidator) +        ] + +    # URLField does not need to include the URLValidator argument, +    # as it is explicitly added in. +    if isinstance(model_field, models.URLField): +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if not isinstance(validator, validators.URLValidator) +        ] + +    # EmailField does not need to include the validate_email argument, +    # as it is explicitly added in. +    if isinstance(model_field, models.EmailField): +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if validator is not validators.validate_email +        ] + +    # SlugField do not need to include the 'validate_slug' argument, +    if isinstance(model_field, models.SlugField): +        validator_kwarg = [ +            validator for validator in validator_kwarg +            if validator is not validators.validate_slug +        ] + +    max_digits = getattr(model_field, 'max_digits', None) +    if max_digits is not None: +        kwargs['max_digits'] = max_digits + +    decimal_places = getattr(model_field, 'decimal_places', None) +    if decimal_places is not None: +        kwargs['decimal_places'] = decimal_places + +    if isinstance(model_field, models.BooleanField): +        # models.BooleanField has `blank=True`, but *is* actually +        # required *unless* a default is provided. +        # Also note that Django<1.6 uses `default=False` for +        # models.BooleanField, but Django>=1.6 uses `default=None`. +        kwargs.pop('required', None) + +    if validator_kwarg: +        kwargs['validators'] = validator_kwarg + +    # The following will only be used by ModelField classes. +    # Gets removed for everything else. +    kwargs['model_field'] = model_field + +    return kwargs + + +def get_relation_kwargs(field_name, relation_info): +    """ +    Creates a default instance of a flat relational field. +    """ +    model_field, related_model, to_many, has_through_model = relation_info +    kwargs = { +        'queryset': related_model._default_manager, +        'view_name': get_detail_view_name(related_model) +    } + +    if to_many: +        kwargs['many'] = True + +    if has_through_model: +        kwargs['read_only'] = True +        kwargs.pop('queryset', None) + +    if model_field: +        if model_field.null or model_field.blank: +            kwargs['required'] = False +        if model_field.verbose_name and needs_label(model_field, field_name): +            kwargs['label'] = capfirst(model_field.verbose_name) +        if not model_field.editable: +            kwargs['read_only'] = True +            kwargs.pop('queryset', None) +        help_text = clean_manytomany_helptext(model_field.help_text) +        if help_text: +            kwargs['help_text'] = help_text + +    return kwargs + + +def get_nested_relation_kwargs(relation_info): +    kwargs = {'read_only': True} +    if relation_info.to_many: +        kwargs['many'] = True +    return kwargs + + +def get_url_kwargs(model_field): +    return { +        'view_name': get_detail_view_name(model_field) +    } diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 960fa4d0..b6c41174 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -1,7 +1,9 @@  """ -Helper functions for returning the field information that is associated +Helper function for returning the field information that is associated  with a model class. This includes returning all the forward and reverse  relationships and their associated metadata. + +Usage: `get_field_info(model)` returns a `FieldInfo` instance.  """  from collections import namedtuple  from django.db import models @@ -9,8 +11,22 @@ from django.utils import six  from django.utils.datastructures import SortedDict  import inspect -FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) -RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) + +FieldInfo = namedtuple('FieldResult', [ +    'pk',  # Model field instance +    'fields',  # Dict of field name -> model field instance +    'forward_relations',  # Dict of field name -> RelationInfo +    'reverse_relations',  # Dict of field name -> RelationInfo +    'fields_and_pk',  # Shortcut for 'pk' + 'fields' +    'relations'  # Shortcut for 'forward_relations' + 'reverse_relations' +]) + +RelationInfo = namedtuple('RelationInfo', [ +    'model_field', +    'related', +    'to_many', +    'has_through_model' +])  def _resolve_model(obj): @@ -55,7 +71,7 @@ def get_field_info(model):      forward_relations = SortedDict()      for field in [field for field in opts.fields if field.serialize and field.rel]:          forward_relations[field.name] = RelationInfo( -            field=field, +            model_field=field,              related=_resolve_model(field.rel.to),              to_many=False,              has_through_model=False @@ -64,7 +80,7 @@ def get_field_info(model):      # Deal with forward many-to-many relationships.      for field in [field for field in opts.many_to_many if field.serialize]:          forward_relations[field.name] = RelationInfo( -            field=field, +            model_field=field,              related=_resolve_model(field.rel.to),              to_many=True,              has_through_model=( @@ -77,7 +93,7 @@ def get_field_info(model):      for relation in opts.get_all_related_objects():          accessor_name = relation.get_accessor_name()          reverse_relations[accessor_name] = RelationInfo( -            field=None, +            model_field=None,              related=relation.model,              to_many=relation.field.rel.multiple,              has_through_model=False @@ -87,7 +103,7 @@ def get_field_info(model):      for relation in opts.get_all_related_many_to_many_objects():          accessor_name = relation.get_accessor_name()          reverse_relations[accessor_name] = RelationInfo( -            field=None, +            model_field=None,              related=relation.model,              to_many=True,              has_through_model=( @@ -96,4 +112,18 @@ def get_field_info(model):              )          ) -    return FieldInfo(pk, fields, forward_relations, reverse_relations) +    # Shortcut that merges both regular fields and the pk, +    # for simplifying regular field lookup. +    fields_and_pk = SortedDict() +    fields_and_pk['pk'] = pk +    fields_and_pk[pk.name] = pk +    fields_and_pk.update(fields) + +    # Shortcut that merges both forward and reverse relationships + +    relations = SortedDict( +        list(forward_relations.items()) + +        list(reverse_relations.items()) +    ) + +    return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations) diff --git a/tests/models.py b/tests/models.py index fe064b46..06ec5a22 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,7 +1,6 @@  from __future__ import unicode_literals  from django.db import models  from django.utils.translation import ugettext_lazy as _ -from rest_framework import serializers  def foobar(): @@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel):      name = models.CharField(max_length=100)      target = models.OneToOneField(OneToOneTarget, null=True, blank=True,                                    related_name='nullable_source') - - -# Serializer used to test BasicModel -class BasicModelSerializer(serializers.ModelSerializer): -    class Meta: -        model = BasicModel diff --git a/tests/test_model_field_mappings.py b/tests/test_model_field_mappings.py index bae63e5a..6daa574e 100644 --- a/tests/test_model_field_mappings.py +++ b/tests/test_model_field_mappings.py @@ -126,16 +126,16 @@ class TestRelationalFieldMappings(TestCase):          expected = dedent("""              TestSerializer():                  id = IntegerField(label='ID', read_only=True) -                foreign_key = NestedModelSerializer(read_only=True): +                foreign_key = NestedSerializer(read_only=True):                      id = IntegerField(label='ID', read_only=True)                      name = CharField(max_length=100) -                one_to_one = NestedModelSerializer(read_only=True): +                one_to_one = NestedSerializer(read_only=True):                      id = IntegerField(label='ID', read_only=True)                      name = CharField(max_length=100) -                many_to_many = NestedModelSerializer(many=True, read_only=True): +                many_to_many = NestedSerializer(many=True, read_only=True):                      id = IntegerField(label='ID', read_only=True)                      name = CharField(max_length=100) -                through = NestedModelSerializer(many=True, read_only=True): +                through = NestedSerializer(many=True, read_only=True):                      id = IntegerField(label='ID', read_only=True)                      name = CharField(max_length=100)          """) @@ -165,16 +165,16 @@ class TestRelationalFieldMappings(TestCase):          expected = dedent("""              TestSerializer():                  url = HyperlinkedIdentityField(view_name='relationalmodel-detail') -                foreign_key = NestedModelSerializer(read_only=True): +                foreign_key = NestedSerializer(read_only=True):                      url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')                      name = CharField(max_length=100) -                one_to_one = NestedModelSerializer(read_only=True): +                one_to_one = NestedSerializer(read_only=True):                      url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')                      name = CharField(max_length=100) -                many_to_many = NestedModelSerializer(many=True, read_only=True): +                many_to_many = NestedSerializer(many=True, read_only=True):                      url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')                      name = CharField(max_length=100) -                through = NestedModelSerializer(many=True, read_only=True): +                through = NestedSerializer(many=True, read_only=True):                      url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')                      name = CharField(max_length=100)          """) diff --git a/tests/test_response.py b/tests/test_response.py index 67419a71..84c39c1a 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -2,11 +2,12 @@ from __future__ import unicode_literals  from django.conf.urls import patterns, url, include  from django.test import TestCase  from django.utils import six -from tests.models import BasicModel, BasicModelSerializer +from tests.models import BasicModel  from rest_framework.response import Response  from rest_framework.views import APIView  from rest_framework import generics  from rest_framework import routers +from rest_framework import serializers  from rest_framework import status  from rest_framework.renderers import (      BaseRenderer, @@ -17,6 +18,12 @@ from rest_framework import viewsets  from rest_framework.settings import api_settings +# Serializer used to test BasicModel +class BasicModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = BasicModel + +  class MockPickleRenderer(BaseRenderer):      media_type = 'application/pickle' diff --git a/tests/test_routers.py b/tests/test_routers.py index b076f134..c2d595f7 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase):      def setUp(self):          class NoteSerializer(serializers.HyperlinkedModelSerializer): +            url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') +              class Meta:                  model = RouterTestModel -                lookup_field = 'uuid'                  fields = ('url', 'uuid', 'text')          class NoteViewSet(viewsets.ModelViewSet): @@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase):              serializer_class = NoteSerializer              lookup_field = 'uuid' -        RouterTestModel.objects.create(uuid='123', text='foo bar') -          self.router = SimpleRouter()          self.router.register(r'notes', NoteViewSet) @@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase):              url(r'^', include(self.router.urls)),          ) +        RouterTestModel.objects.create(uuid='123', text='foo bar') +      def test_custom_lookup_field_route(self):          detail_route = self.router.urls[-1]          detail_url_pattern = detail_route.regex.pattern | 
