diff options
| author | Tom Christie | 2014-09-18 11:20:56 +0100 |
|---|---|---|
| committer | Tom Christie | 2014-09-18 11:20:56 +0100 |
| commit | 5b7e4af0d657a575cb15eea85a63a7100c636085 (patch) | |
| tree | 798e30ea326324151f4e87319156fb2b35147792 | |
| parent | c0155fd9dc654dc5932effd46a00f66495ce700b (diff) | |
| download | django-rest-framework-5b7e4af0d657a575cb15eea85a63a7100c636085.tar.bz2 | |
get_base_field() refactor
| -rw-r--r-- | rest_framework/fields.py | 6 | ||||
| -rw-r--r-- | rest_framework/relations.py | 9 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 464 | ||||
| -rw-r--r-- | rest_framework/utils/field_mapping.py | 215 | ||||
| -rw-r--r-- | rest_framework/utils/model_meta.py | 46 | ||||
| -rw-r--r-- | tests/models.py | 7 | ||||
| -rw-r--r-- | tests/test_model_field_mappings.py | 16 | ||||
| -rw-r--r-- | tests/test_response.py | 9 | ||||
| -rw-r--r-- | tests/test_routers.py | 7 |
9 files changed, 384 insertions, 395 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1818e705..0c78b3fb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -80,10 +80,6 @@ def set_value(dictionary, keys, value): dictionary[keys[-1]] = value -def field_name_to_label(field_name): - return field_name.replace('_', ' ').capitalize() - - class SkipField(Exception): pass @@ -162,7 +158,7 @@ class Field(object): # `self.label` should deafult to being based on the field name. if self.label is None: - self.label = field_name_to_label(self.field_name) + self.label = field_name.replace('_', ' ').capitalize() # self.source should default to being the same as the field name. if self.source is None: diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 46fe55ef..9f44ab63 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -73,7 +73,8 @@ class HyperlinkedRelatedField(RelatedField): 'incorrect_type': 'Incorrect type. Expected URL string, received {data_type}.', } - def __init__(self, view_name, **kwargs): + def __init__(self, view_name=None, **kwargs): + assert view_name is not None, 'The `view_name` argument is required.' self.view_name = view_name self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) @@ -182,7 +183,8 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField): URL of relationships to other objects. """ - def __init__(self, view_name, **kwargs): + def __init__(self, view_name=None, **kwargs): + assert view_name is not None, 'The `view_name` argument is required.' kwargs['read_only'] = True kwargs['source'] = '*' super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) @@ -199,7 +201,8 @@ class SlugRelatedField(RelatedField): 'invalid': _('Invalid value.'), } - def __init__(self, slug_field, **kwargs): + def __init__(self, slug_field=None, **kwargs): + assert slug_field is not None, 'The `slug_field` argument is required.' self.slug_field = slug_field super(SlugRelatedField, self).__init__(**kwargs) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 1fea1380..99dcc349 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,17 +10,19 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ -from django.core import validators from django.core.exceptions import ValidationError from django.db import models from django.utils import six from django.utils.datastructures import SortedDict -from django.utils.text import capfirst from collections import namedtuple -from rest_framework.compat import clean_manytomany_helptext from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings from rest_framework.utils import html, model_meta, representation +from rest_framework.utils.field_mapping import ( + get_url_kwargs, get_field_kwargs, + get_relation_kwargs, get_nested_relation_kwargs, + lookup_class +) import copy # Note: We do the following so that users of the framework can use this style: @@ -126,7 +128,7 @@ class SerializerMetaclass(type): """ @classmethod - def _get_fields(cls, bases, attrs): + def _get_declared_fields(cls, bases, attrs): fields = [(field_name, attrs.pop(field_name)) for field_name, obj in list(attrs.items()) if isinstance(obj, Field)] @@ -136,25 +138,18 @@ class SerializerMetaclass(type): # fields. Note that we loop over the bases in *reverse*. This is necessary # in order to maintain the correct order of fields. for base in bases[::-1]: - if hasattr(base, 'base_fields'): - fields = list(base.base_fields.items()) + fields + if hasattr(base, '_declared_fields'): + fields = list(base._declared_fields.items()) + fields return SortedDict(fields) def __new__(cls, name, bases, attrs): - attrs['base_fields'] = cls._get_fields(bases, attrs) + attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): - - def __new__(cls, *args, **kwargs): - if kwargs.pop('many', False): - kwargs['child'] = cls() - return ListSerializer(*args, **kwargs) - return super(Serializer, cls).__new__(cls, *args, **kwargs) - def __init__(self, *args, **kwargs): self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) @@ -165,14 +160,22 @@ class Serializer(BaseSerializer): # Every new serializer is created with a clone of the field instances. # This allows users to dynamically modify the fields on a serializer # instance without affecting every other serializer class. - self.fields = self.get_fields() + self.fields = self._get_base_fields() # Setup all the child fields, to provide them with the current context. for field_name, field in self.fields.items(): field.bind(field_name, self, self) - def get_fields(self): - return copy.deepcopy(self.base_fields) + def __new__(cls, *args, **kwargs): + # We override this method in order to automagically create + # `ListSerializer` classes instead when `many=True` is set. + if kwargs.pop('many', False): + kwargs['child'] = cls() + return ListSerializer(*args, **kwargs) + return super(Serializer, cls).__new__(cls, *args, **kwargs) + + def _get_base_fields(self): + return copy.deepcopy(self._declared_fields) def bind(self, field_name, parent, root): # If the serializer is used as a field then when it becomes bound @@ -312,39 +315,8 @@ class ListSerializer(BaseSerializer): return representation.list_repr(self, indent=1) -class ModelSerializerOptions(object): - """ - Meta class options for ModelSerializer - """ - def __init__(self, meta): - self.model = getattr(meta, 'model') - self.fields = getattr(meta, 'fields', ()) - self.depth = getattr(meta, 'depth', 0) - - -def lookup_class(mapping, instance): - """ - Takes a dictionary with classes as keys, and an object. - Traverses the object's inheritance hierarchy in method - resolution order, and returns the first matching value - from the dictionary or raises a KeyError if nothing matches. - """ - for cls in inspect.getmro(instance.__class__): - if cls in mapping: - return mapping[cls] - raise KeyError('Class %s not found in lookup.', cls.__name__) - - -def needs_label(model_field, field_name): - """ - Returns `True` if the label based on the model's verbose name - is not equal to the default label it would have based on it's field name. - """ - return capfirst(model_field.verbose_name) != field_name_to_label(field_name) - - class ModelSerializer(Serializer): - field_mapping = { + _field_mapping = { models.AutoField: IntegerField, models.BigIntegerField: IntegerField, models.BooleanField: BooleanField, @@ -368,16 +340,10 @@ class ModelSerializer(Serializer): models.TimeField: TimeField, models.URLField: URLField, } - nested_class = None # We fill this in at the end of this module. - - _options_class = ModelSerializerOptions - - def __init__(self, *args, **kwargs): - self.opts = self._options_class(self.Meta) - super(ModelSerializer, self).__init__(*args, **kwargs) + _related_class = PrimaryKeyRelatedField def create(self, attrs): - ModelClass = self.opts.model + ModelClass = self.Meta.model return ModelClass.objects.create(**attrs) def update(self, obj, attrs): @@ -385,319 +351,97 @@ class ModelSerializer(Serializer): setattr(obj, attr, value) obj.save() - def get_fields(self): - # Get the explicitly declared fields. - fields = copy.deepcopy(self.base_fields) + def _get_base_fields(self): + declared_fields = copy.deepcopy(self._declared_fields) - # Add in the default fields. - for key, val in self.get_default_fields().items(): - if key not in fields: - fields[key] = val - - # If `fields` is set on the `Meta` class, - # then use only those fields, and in that order. - if self.opts.fields: - fields = SortedDict([ - (key, fields[key]) for key in self.opts.fields - ]) - - return fields - - def get_default_fields(self): - """ - Return all the fields that should be serialized for the model. - """ - info = model_meta.get_field_info(self.opts.model) ret = SortedDict() + model = getattr(self.Meta, 'model') + fields = getattr(self.Meta, 'fields', None) + depth = getattr(self.Meta, 'depth', 0) + + # Retrieve metadata about fields & relationships on the model class. + info = model_meta.get_field_info(model) + + # Use the default set of fields if none is supplied explicitly. + if fields is None: + fields = self._get_default_field_names(declared_fields, info) + + for field_name in fields: + if field_name in declared_fields: + # Field is explicitly declared on the class, use that. + ret[field_name] = declared_fields[field_name] + continue + + elif field_name == api_settings.URL_FIELD_NAME: + # Create the URL field. + field_cls = HyperlinkedIdentityField + kwargs = get_url_kwargs(model) + + elif field_name in info.fields_and_pk: + # Create regular model fields. + model_field = info.fields_and_pk[field_name] + field_cls = lookup_class(self._field_mapping, model_field) + kwargs = get_field_kwargs(field_name, model_field) + if 'choices' in kwargs: + # Fields with choices get coerced into `ChoiceField` + # instead of using their regular typed field. + field_cls = ChoiceField + if not issubclass(field_cls, ModelField): + # `model_field` is only valid for the fallback case of + # `ModelField`, which is used when no other typed field + # matched to the model field. + kwargs.pop('model_field', None) + + elif field_name in info.relations: + # Create forward and reverse relationships. + relation_info = info.relations[field_name] + if depth: + field_cls = self._get_nested_class(depth, relation_info) + kwargs = get_nested_relation_kwargs(relation_info) + else: + field_cls = self._related_class + kwargs = get_relation_kwargs(field_name, relation_info) + # `view_name` is only valid for hyperlinked relationships. + if not issubclass(field_cls, HyperlinkedRelatedField): + kwargs.pop('view_name', None) - # URL field - serializer_url_field = self.get_url_field() - if serializer_url_field: - ret[api_settings.URL_FIELD_NAME] = serializer_url_field - - # Primary key field - field_name = info.pk.name - serializer_pk_field = self.get_pk_field(field_name, info.pk) - if serializer_pk_field: - ret[field_name] = serializer_pk_field - - # Regular fields - for field_name, field in info.fields.items(): - ret[field_name] = self.get_field(field_name, field) - - # Forward relations - for field_name, relation_info in info.forward_relations.items(): - if self.opts.depth: - ret[field_name] = self.get_nested_field(field_name, *relation_info) else: - ret[field_name] = self.get_related_field(field_name, *relation_info) + assert False, 'Field name `%s` is not valid.' % field_name - # Reverse relations - for accessor_name, relation_info in info.reverse_relations.items(): - if accessor_name in self.opts.fields: - if self.opts.depth: - ret[accessor_name] = self.get_nested_field(accessor_name, *relation_info) - else: - ret[accessor_name] = self.get_related_field(accessor_name, *relation_info) + ret[field_name] = field_cls(**kwargs) return ret - def get_url_field(self): - return None - - def get_pk_field(self, field_name, model_field): - """ - Returns a default instance of the pk field. - """ - return self.get_field(field_name, model_field) - - def get_nested_field(self, field_name, model_field, related_model, to_many, has_through_model): - """ - Creates a default instance of a nested relational field. + def _get_default_field_names(self, declared_fields, model_info): + return ( + [model_info.pk.name] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) - Note that model_field will be `None` for reverse relationships. - """ - class NestedModelSerializer(self.nested_class): + def _get_nested_class(self, nested_depth, relation_info): + class NestedSerializer(ModelSerializer): class Meta: - model = related_model - depth = self.opts.depth - 1 - - kwargs = {'read_only': True} - if to_many: - kwargs['many'] = True - return NestedModelSerializer(**kwargs) - - def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): - """ - Creates a default instance of a flat relational field. - - Note that model_field will be `None` for reverse relationships. - """ - kwargs = { - 'queryset': related_model._default_manager, - } - - if to_many: - kwargs['many'] = True - - if has_through_model: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - - if model_field: - if model_field.null or model_field.blank: - kwargs['required'] = False - if model_field.verbose_name and needs_label(model_field, field_name): - kwargs['label'] = capfirst(model_field.verbose_name) - if not model_field.editable: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - help_text = clean_manytomany_helptext(model_field.help_text) - if help_text: - kwargs['help_text'] = help_text - - return PrimaryKeyRelatedField(**kwargs) - - def get_field(self, field_name, model_field): - """ - Creates a default instance of a basic non-relational field. - """ - serializer_cls = lookup_class(self.field_mapping, model_field) - kwargs = {} - validator_kwarg = model_field.validators - - if model_field.null or model_field.blank: - kwargs['required'] = False - - if model_field.verbose_name and needs_label(model_field, field_name): - kwargs['label'] = capfirst(model_field.verbose_name) - - if model_field.help_text: - kwargs['help_text'] = model_field.help_text - - if isinstance(model_field, models.AutoField) or not model_field.editable: - kwargs['read_only'] = True - # Read only implies that the field is not required. - # We have a cleaner repr on the instance if we don't set it. - kwargs.pop('required', None) - - if model_field.has_default(): - kwargs['default'] = model_field.get_default() - # Having a default implies that the field is not required. - # We have a cleaner repr on the instance if we don't set it. - kwargs.pop('required', None) - - if model_field.flatchoices: - # If this model field contains choices, then use a ChoiceField, - # rather than the standard serializer field for this type. - # Note that we return this prior to setting any validation type - # keyword arguments, as those are not valid initializers. - kwargs['choices'] = model_field.flatchoices - return ChoiceField(**kwargs) - - # Ensure that max_length is passed explicitly as a keyword arg, - # rather than as a validator. - max_length = getattr(model_field, 'max_length', None) - if max_length is not None: - kwargs['max_length'] = max_length - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.MaxLengthValidator) - ] - - # Ensure that min_length is passed explicitly as a keyword arg, - # rather than as a validator. - min_length = getattr(model_field, 'min_length', None) - if min_length is not None: - kwargs['min_length'] = min_length - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.MinLengthValidator) - ] - - # Ensure that max_value is passed explicitly as a keyword arg, - # rather than as a validator. - max_value = next(( - validator.limit_value for validator in validator_kwarg - if isinstance(validator, validators.MaxValueValidator) - ), None) - if max_value is not None: - kwargs['max_value'] = max_value - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.MaxValueValidator) - ] - - # Ensure that max_value is passed explicitly as a keyword arg, - # rather than as a validator. - min_value = next(( - validator.limit_value for validator in validator_kwarg - if isinstance(validator, validators.MinValueValidator) - ), None) - if min_value is not None: - kwargs['min_value'] = min_value - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.MinValueValidator) - ] - - # URLField does not need to include the URLValidator argument, - # as it is explicitly added in. - if isinstance(model_field, models.URLField): - validator_kwarg = [ - validator for validator in validator_kwarg - if not isinstance(validator, validators.URLValidator) - ] - - # EmailField does not need to include the validate_email argument, - # as it is explicitly added in. - if isinstance(model_field, models.EmailField): - validator_kwarg = [ - validator for validator in validator_kwarg - if validator is not validators.validate_email - ] - - # SlugField do not need to include the 'validate_slug' argument, - if isinstance(model_field, models.SlugField): - validator_kwarg = [ - validator for validator in validator_kwarg - if validator is not validators.validate_slug - ] - - max_digits = getattr(model_field, 'max_digits', None) - if max_digits is not None: - kwargs['max_digits'] = max_digits - - decimal_places = getattr(model_field, 'decimal_places', None) - if decimal_places is not None: - kwargs['decimal_places'] = decimal_places - - if isinstance(model_field, models.BooleanField): - # models.BooleanField has `blank=True`, but *is* actually - # required *unless* a default is provided. - # Also note that Django<1.6 uses `default=False` for - # models.BooleanField, but Django>=1.6 uses `default=None`. - kwargs.pop('required', None) - - if validator_kwarg: - kwargs['validators'] = validator_kwarg - - if issubclass(serializer_cls, ModelField): - kwargs['model_field'] = model_field - - return serializer_cls(**kwargs) - - -class HyperlinkedModelSerializerOptions(ModelSerializerOptions): - """ - Options for HyperlinkedModelSerializer - """ - def __init__(self, meta): - super(HyperlinkedModelSerializerOptions, self).__init__(meta) - self.view_name = getattr(meta, 'view_name', None) - self.lookup_field = getattr(meta, 'lookup_field', None) + model = relation_info.related + depth = nested_depth + return NestedSerializer class HyperlinkedModelSerializer(ModelSerializer): - _options_class = HyperlinkedModelSerializerOptions - - def get_url_field(self): - if self.opts.view_name is not None: - view_name = self.opts.view_name - else: - view_name = self.get_default_view_name(self.opts.model) - - kwargs = { - 'view_name': view_name - } - if self.opts.lookup_field: - kwargs['lookup_field'] = self.opts.lookup_field - - return HyperlinkedIdentityField(**kwargs) - - def get_pk_field(self, field_name, model_field): - if self.opts.fields and model_field.name in self.opts.fields: - return self.get_field(model_field) - - def get_related_field(self, field_name, model_field, related_model, to_many, has_through_model): - """ - Creates a default instance of a flat relational field. - """ - kwargs = { - 'queryset': related_model._default_manager, - 'view_name': self.get_default_view_name(related_model), - } - - if to_many: - kwargs['many'] = True - - if has_through_model: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - - if model_field: - if model_field.null or model_field.blank: - kwargs['required'] = False - if model_field.verbose_name and needs_label(model_field, field_name): - kwargs['label'] = capfirst(model_field.verbose_name) - if not model_field.editable: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - help_text = clean_manytomany_helptext(model_field.help_text) - if help_text: - kwargs['help_text'] = help_text - - return HyperlinkedRelatedField(**kwargs) - - def get_default_view_name(self, model): - """ - Return the view name to use for related models. - """ - return '%(model_name)s-detail' % { - 'app_label': model._meta.app_label, - 'model_name': model._meta.object_name.lower() - } - - -ModelSerializer.nested_class = ModelSerializer -HyperlinkedModelSerializer.nested_class = HyperlinkedModelSerializer + _related_class = HyperlinkedRelatedField + + def _get_default_field_names(self, declared_fields, model_info): + return ( + [api_settings.URL_FIELD_NAME] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) + + def _get_nested_class(self, nested_depth, relation_info): + class NestedSerializer(HyperlinkedModelSerializer): + class Meta: + model = relation_info.related + depth = nested_depth + return NestedSerializer diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py new file mode 100644 index 00000000..be72e444 --- /dev/null +++ b/rest_framework/utils/field_mapping.py @@ -0,0 +1,215 @@ +""" +Helper functions for mapping model fields to a dictionary of default +keyword arguments that should be used for their equivelent serializer fields. +""" +from django.core import validators +from django.db import models +from django.utils.text import capfirst +from rest_framework.compat import clean_manytomany_helptext +import inspect + + +def lookup_class(mapping, instance): + """ + Takes a dictionary with classes as keys, and an object. + Traverses the object's inheritance hierarchy in method + resolution order, and returns the first matching value + from the dictionary or raises a KeyError if nothing matches. + """ + for cls in inspect.getmro(instance.__class__): + if cls in mapping: + return mapping[cls] + raise KeyError('Class %s not found in lookup.', cls.__name__) + + +def needs_label(model_field, field_name): + """ + Returns `True` if the label based on the model's verbose name + is not equal to the default label it would have based on it's field name. + """ + default_label = field_name.replace('_', ' ').capitalize() + return capfirst(model_field.verbose_name) != default_label + + +def get_detail_view_name(model): + """ + Given a model class, return the view name to use for URL relationships + that refer to instances of the model. + """ + return '%(model_name)s-detail' % { + 'app_label': model._meta.app_label, + 'model_name': model._meta.object_name.lower() + } + + +def get_field_kwargs(field_name, model_field): + """ + Creates a default instance of a basic non-relational field. + """ + kwargs = {} + validator_kwarg = model_field.validators + + if model_field.null or model_field.blank: + kwargs['required'] = False + + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) + + if model_field.help_text: + kwargs['help_text'] = model_field.help_text + + if isinstance(model_field, models.AutoField) or not model_field.editable: + kwargs['read_only'] = True + # Read only implies that the field is not required. + # We have a cleaner repr on the instance if we don't set it. + kwargs.pop('required', None) + + if model_field.has_default(): + kwargs['default'] = model_field.get_default() + # Having a default implies that the field is not required. + # We have a cleaner repr on the instance if we don't set it. + kwargs.pop('required', None) + + if model_field.flatchoices: + # If this model field contains choices, then return now, + # any further keyword arguments are not valid. + kwargs['choices'] = model_field.flatchoices + return kwargs + + # Ensure that max_length is passed explicitly as a keyword arg, + # rather than as a validator. + max_length = getattr(model_field, 'max_length', None) + if max_length is not None: + kwargs['max_length'] = max_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxLengthValidator) + ] + + # Ensure that min_length is passed explicitly as a keyword arg, + # rather than as a validator. + min_length = getattr(model_field, 'min_length', None) + if min_length is not None: + kwargs['min_length'] = min_length + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinLengthValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + max_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MaxValueValidator) + ), None) + if max_value is not None: + kwargs['max_value'] = max_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MaxValueValidator) + ] + + # Ensure that max_value is passed explicitly as a keyword arg, + # rather than as a validator. + min_value = next(( + validator.limit_value for validator in validator_kwarg + if isinstance(validator, validators.MinValueValidator) + ), None) + if min_value is not None: + kwargs['min_value'] = min_value + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.MinValueValidator) + ] + + # URLField does not need to include the URLValidator argument, + # as it is explicitly added in. + if isinstance(model_field, models.URLField): + validator_kwarg = [ + validator for validator in validator_kwarg + if not isinstance(validator, validators.URLValidator) + ] + + # EmailField does not need to include the validate_email argument, + # as it is explicitly added in. + if isinstance(model_field, models.EmailField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_email + ] + + # SlugField do not need to include the 'validate_slug' argument, + if isinstance(model_field, models.SlugField): + validator_kwarg = [ + validator for validator in validator_kwarg + if validator is not validators.validate_slug + ] + + max_digits = getattr(model_field, 'max_digits', None) + if max_digits is not None: + kwargs['max_digits'] = max_digits + + decimal_places = getattr(model_field, 'decimal_places', None) + if decimal_places is not None: + kwargs['decimal_places'] = decimal_places + + if isinstance(model_field, models.BooleanField): + # models.BooleanField has `blank=True`, but *is* actually + # required *unless* a default is provided. + # Also note that Django<1.6 uses `default=False` for + # models.BooleanField, but Django>=1.6 uses `default=None`. + kwargs.pop('required', None) + + if validator_kwarg: + kwargs['validators'] = validator_kwarg + + # The following will only be used by ModelField classes. + # Gets removed for everything else. + kwargs['model_field'] = model_field + + return kwargs + + +def get_relation_kwargs(field_name, relation_info): + """ + Creates a default instance of a flat relational field. + """ + model_field, related_model, to_many, has_through_model = relation_info + kwargs = { + 'queryset': related_model._default_manager, + 'view_name': get_detail_view_name(related_model) + } + + if to_many: + kwargs['many'] = True + + if has_through_model: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + + if model_field: + if model_field.null or model_field.blank: + kwargs['required'] = False + if model_field.verbose_name and needs_label(model_field, field_name): + kwargs['label'] = capfirst(model_field.verbose_name) + if not model_field.editable: + kwargs['read_only'] = True + kwargs.pop('queryset', None) + help_text = clean_manytomany_helptext(model_field.help_text) + if help_text: + kwargs['help_text'] = help_text + + return kwargs + + +def get_nested_relation_kwargs(relation_info): + kwargs = {'read_only': True} + if relation_info.to_many: + kwargs['many'] = True + return kwargs + + +def get_url_kwargs(model_field): + return { + 'view_name': get_detail_view_name(model_field) + } diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 960fa4d0..b6c41174 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -1,7 +1,9 @@ """ -Helper functions for returning the field information that is associated +Helper function for returning the field information that is associated with a model class. This includes returning all the forward and reverse relationships and their associated metadata. + +Usage: `get_field_info(model)` returns a `FieldInfo` instance. """ from collections import namedtuple from django.db import models @@ -9,8 +11,22 @@ from django.utils import six from django.utils.datastructures import SortedDict import inspect -FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations']) -RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model']) + +FieldInfo = namedtuple('FieldResult', [ + 'pk', # Model field instance + 'fields', # Dict of field name -> model field instance + 'forward_relations', # Dict of field name -> RelationInfo + 'reverse_relations', # Dict of field name -> RelationInfo + 'fields_and_pk', # Shortcut for 'pk' + 'fields' + 'relations' # Shortcut for 'forward_relations' + 'reverse_relations' +]) + +RelationInfo = namedtuple('RelationInfo', [ + 'model_field', + 'related', + 'to_many', + 'has_through_model' +]) def _resolve_model(obj): @@ -55,7 +71,7 @@ def get_field_info(model): forward_relations = SortedDict() for field in [field for field in opts.fields if field.serialize and field.rel]: forward_relations[field.name] = RelationInfo( - field=field, + model_field=field, related=_resolve_model(field.rel.to), to_many=False, has_through_model=False @@ -64,7 +80,7 @@ def get_field_info(model): # Deal with forward many-to-many relationships. for field in [field for field in opts.many_to_many if field.serialize]: forward_relations[field.name] = RelationInfo( - field=field, + model_field=field, related=_resolve_model(field.rel.to), to_many=True, has_through_model=( @@ -77,7 +93,7 @@ def get_field_info(model): for relation in opts.get_all_related_objects(): accessor_name = relation.get_accessor_name() reverse_relations[accessor_name] = RelationInfo( - field=None, + model_field=None, related=relation.model, to_many=relation.field.rel.multiple, has_through_model=False @@ -87,7 +103,7 @@ def get_field_info(model): for relation in opts.get_all_related_many_to_many_objects(): accessor_name = relation.get_accessor_name() reverse_relations[accessor_name] = RelationInfo( - field=None, + model_field=None, related=relation.model, to_many=True, has_through_model=( @@ -96,4 +112,18 @@ def get_field_info(model): ) ) - return FieldInfo(pk, fields, forward_relations, reverse_relations) + # Shortcut that merges both regular fields and the pk, + # for simplifying regular field lookup. + fields_and_pk = SortedDict() + fields_and_pk['pk'] = pk + fields_and_pk[pk.name] = pk + fields_and_pk.update(fields) + + # Shortcut that merges both forward and reverse relationships + + relations = SortedDict( + list(forward_relations.items()) + + list(reverse_relations.items()) + ) + + return FieldInfo(pk, fields, forward_relations, reverse_relations, fields_and_pk, relations) diff --git a/tests/models.py b/tests/models.py index fe064b46..06ec5a22 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals from django.db import models from django.utils.translation import ugettext_lazy as _ -from rest_framework import serializers def foobar(): @@ -178,9 +177,3 @@ class NullableOneToOneSource(RESTFrameworkModel): name = models.CharField(max_length=100) target = models.OneToOneField(OneToOneTarget, null=True, blank=True, related_name='nullable_source') - - -# Serializer used to test BasicModel -class BasicModelSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel diff --git a/tests/test_model_field_mappings.py b/tests/test_model_field_mappings.py index bae63e5a..6daa574e 100644 --- a/tests/test_model_field_mappings.py +++ b/tests/test_model_field_mappings.py @@ -126,16 +126,16 @@ class TestRelationalFieldMappings(TestCase): expected = dedent(""" TestSerializer(): id = IntegerField(label='ID', read_only=True) - foreign_key = NestedModelSerializer(read_only=True): + foreign_key = NestedSerializer(read_only=True): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) - one_to_one = NestedModelSerializer(read_only=True): + one_to_one = NestedSerializer(read_only=True): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) - many_to_many = NestedModelSerializer(many=True, read_only=True): + many_to_many = NestedSerializer(many=True, read_only=True): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) - through = NestedModelSerializer(many=True, read_only=True): + through = NestedSerializer(many=True, read_only=True): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) """) @@ -165,16 +165,16 @@ class TestRelationalFieldMappings(TestCase): expected = dedent(""" TestSerializer(): url = HyperlinkedIdentityField(view_name='relationalmodel-detail') - foreign_key = NestedModelSerializer(read_only=True): + foreign_key = NestedSerializer(read_only=True): url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail') name = CharField(max_length=100) - one_to_one = NestedModelSerializer(read_only=True): + one_to_one = NestedSerializer(read_only=True): url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail') name = CharField(max_length=100) - many_to_many = NestedModelSerializer(many=True, read_only=True): + many_to_many = NestedSerializer(many=True, read_only=True): url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail') name = CharField(max_length=100) - through = NestedModelSerializer(many=True, read_only=True): + through = NestedSerializer(many=True, read_only=True): url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') name = CharField(max_length=100) """) diff --git a/tests/test_response.py b/tests/test_response.py index 67419a71..84c39c1a 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -2,11 +2,12 @@ from __future__ import unicode_literals from django.conf.urls import patterns, url, include from django.test import TestCase from django.utils import six -from tests.models import BasicModel, BasicModelSerializer +from tests.models import BasicModel from rest_framework.response import Response from rest_framework.views import APIView from rest_framework import generics from rest_framework import routers +from rest_framework import serializers from rest_framework import status from rest_framework.renderers import ( BaseRenderer, @@ -17,6 +18,12 @@ from rest_framework import viewsets from rest_framework.settings import api_settings +# Serializer used to test BasicModel +class BasicModelSerializer(serializers.ModelSerializer): + class Meta: + model = BasicModel + + class MockPickleRenderer(BaseRenderer): media_type = 'application/pickle' diff --git a/tests/test_routers.py b/tests/test_routers.py index b076f134..c2d595f7 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -76,9 +76,10 @@ class TestCustomLookupFields(TestCase): def setUp(self): class NoteSerializer(serializers.HyperlinkedModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') + class Meta: model = RouterTestModel - lookup_field = 'uuid' fields = ('url', 'uuid', 'text') class NoteViewSet(viewsets.ModelViewSet): @@ -86,8 +87,6 @@ class TestCustomLookupFields(TestCase): serializer_class = NoteSerializer lookup_field = 'uuid' - RouterTestModel.objects.create(uuid='123', text='foo bar') - self.router = SimpleRouter() self.router.register(r'notes', NoteViewSet) @@ -98,6 +97,8 @@ class TestCustomLookupFields(TestCase): url(r'^', include(self.router.urls)), ) + RouterTestModel.objects.create(uuid='123', text='foo bar') + def test_custom_lookup_field_route(self): detail_route = self.router.urls[-1] detail_url_pattern = detail_route.regex.pattern |
