diff options
Diffstat (limited to 'rest_framework/serializers.py')
| -rw-r--r-- | rest_framework/serializers.py | 1737 |
1 files changed, 730 insertions, 1007 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 7d85894f..f00b685f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,22 +10,28 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ -from __future__ import unicode_literals -import copy -import datetime -import inspect -import types -from decimal import Decimal -from django.contrib.contenttypes.generic import GenericForeignKey -from django.core.paginator import Page +from django.core.exceptions import ImproperlyConfigured from django.db import models -from django.forms import widgets +from django.db.models.fields import FieldDoesNotExist from django.utils import six from django.utils.datastructures import SortedDict -from django.utils.functional import cached_property -from django.core.exceptions import ObjectDoesNotExist +from django.utils.translation import ugettext_lazy as _ +from rest_framework.exceptions import ValidationError +from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.settings import api_settings - +from rest_framework.utils import html, model_meta, representation +from rest_framework.utils.field_mapping import ( + get_url_kwargs, get_field_kwargs, + get_relation_kwargs, get_nested_relation_kwargs, + ClassLookupDict +) +from rest_framework.validators import ( + UniqueForDateValidator, UniqueForMonthValidator, UniqueForYearValidator, + UniqueTogetherValidator +) +import copy +import inspect +import warnings # Note: We do the following so that users of the framework can use this style: # @@ -38,1126 +44,843 @@ from rest_framework.relations import * # NOQA from rest_framework.fields import * # NOQA -def _resolve_model(obj): +# BaseSerializer +# -------------- + +class BaseSerializer(Field): + """ + The BaseSerializer class provides a minimal class which may be used + for writing custom serializer implementations. """ - Resolve supplied `obj` to a Django model class. - `obj` must be a Django model class itself, or a string - representation of one. Useful in situtations like GH #1225 where - Django may not have resolved a string-based reference to a model in - another model's foreign key definition. + def __init__(self, instance=None, data=None, **kwargs): + self.instance = instance + self._initial_data = data + self.partial = kwargs.pop('partial', False) + self._context = kwargs.pop('context', {}) + kwargs.pop('many', None) + super(BaseSerializer, self).__init__(**kwargs) - String representations should have the format: - 'appname.ModelName' - """ - if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: - app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) - elif inspect.isclass(obj) and issubclass(obj, models.Model): - return obj - else: - raise ValueError("{0} is not a Django model".format(obj)) + def __new__(cls, *args, **kwargs): + # We override this method in order to automagically create + # `ListSerializer` classes instead when `many=True` is set. + if kwargs.pop('many', False): + kwargs['child'] = cls() + return ListSerializer(*args, **kwargs) + return super(BaseSerializer, cls).__new__(cls, *args, **kwargs) + def to_internal_value(self, data): + raise NotImplementedError('`to_internal_value()` must be implemented.') -def pretty_name(name): - """Converts 'first_name' to 'First name'""" - if not name: - return '' - return name.replace('_', ' ').capitalize() + def to_representation(self, instance): + raise NotImplementedError('`to_representation()` must be implemented.') + def update(self, instance, validated_data): + raise NotImplementedError('`update()` must be implemented.') -class RelationsList(list): - _deleted = [] + def create(self, validated_data): + raise NotImplementedError('`create()` must be implemented.') + def save(self, **kwargs): + validated_data = self.validated_data + if kwargs: + validated_data = dict( + list(validated_data.items()) + + list(kwargs.items()) + ) -class NestedValidationError(ValidationError): - """ - The default ValidationError behavior is to stringify each item in the list - if the messages are a list of error messages. + if self.instance is not None: + self.instance = self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` did not return an object instance.' + ) + else: + self.instance = self.create(validated_data) + assert self.instance is not None, ( + '`create()` did not return an object instance.' + ) - In the case of nested serializers, where the parent has many children, - then the child's `serializer.errors` will be a list of dicts. In the case - of a single child, the `serializer.errors` will be a dict. + return self.instance - We need to override the default behavior to get properly nested error dicts. - """ + def is_valid(self, raise_exception=False): + assert not hasattr(self, 'restore_object'), ( + 'Serializer `%s.%s` has old-style version 2 `.restore_object()` ' + 'that is no longer compatible with REST framework 3. ' + 'Use the new-style `.create()` and `.update()` methods instead.' % + (self.__class__.__module__, self.__class__.__name__) + ) - def __init__(self, message): - if isinstance(message, dict): - self._messages = [message] - else: - self._messages = message + if not hasattr(self, '_validated_data'): + try: + self._validated_data = self.run_validation(self._initial_data) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.detail + else: + self._errors = {} + + if self._errors and raise_exception: + raise ValidationError(self._errors) + + return not bool(self._errors) @property - def messages(self): - return self._messages + def data(self): + if not hasattr(self, '_data'): + if self.instance is not None and not getattr(self, '_errors', None): + self._data = self.to_representation(self.instance) + elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None): + self._data = self.to_representation(self.validated_data) + else: + self._data = self.get_initial() + return self._data + @property + def errors(self): + if not hasattr(self, '_errors'): + msg = 'You must call `.is_valid()` before accessing `.errors`.' + raise AssertionError(msg) + return self._errors + + @property + def validated_data(self): + if not hasattr(self, '_validated_data'): + msg = 'You must call `.is_valid()` before accessing `.validated_data`.' + raise AssertionError(msg) + return self._validated_data -class DictWithMetadata(dict): - """ - A dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overridden to remove the metadata from the dict, since it shouldn't be - pickled and may in some instances be unpickleable. - """ - return dict(self) +# Serializer & ListSerializer classes +# ----------------------------------- -class SortedDictWithMetadata(SortedDict): +class ReturnDict(SortedDict): """ - A sorted dict-like object, that can have additional properties attached. + Return object from `serialier.data` for the `Serializer` class. + Includes a backlink to the serializer instance for renderers + to use if they need richer field information. """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overriden to remove the metadata from the dict, since it shouldn't be - pickle and may in some instances be unpickleable. - """ - return SortedDict(self).__dict__ + def __init__(self, *args, **kwargs): + self.serializer = kwargs.pop('serializer') + super(ReturnDict, self).__init__(*args, **kwargs) -def _is_protected_type(obj): +class ReturnList(list): """ - True if the object is a native datatype that does not need to - be serialized further. + Return object from `serialier.data` for the `SerializerList` class. + Includes a backlink to the serializer instance for renderers + to use if they need richer field information. """ - return isinstance(obj, ( - types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) - ) + def __init__(self, *args, **kwargs): + self.serializer = kwargs.pop('serializer') + super(ReturnList, self).__init__(*args, **kwargs) -def _get_declared_fields(bases, attrs): +class BoundField(object): """ - Create a list of serializer field instances from the passed in 'attrs', - plus any fields on the base classes (in 'bases'). - - Note that all fields from the base classes are used. + A field object that also includes `.value` and `.error` properties. + Returned when iterating over a serializer instance, + providing an API similar to Django forms and form fields. """ - fields = [(field_name, attrs.pop(field_name)) - for field_name, obj in list(six.iteritems(attrs)) - if isinstance(obj, Field)] - fields.sort(key=lambda x: x[1].creation_counter) + def __init__(self, field, value, errors, prefix=''): + self._field = field + self.value = value + self.errors = errors + self.name = prefix + self.field_name - # If this class is subclassing another Serializer, add that Serializer's - # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to maintain the correct order of fields. - for base in bases[::-1]: - if hasattr(base, 'base_fields'): - fields = list(base.base_fields.items()) + fields - - return SortedDict(fields) + def __getattr__(self, attr_name): + return getattr(self._field, attr_name) + @property + def _proxy_class(self): + return self._field.__class__ -class SerializerMetaclass(type): - def __new__(cls, name, bases, attrs): - attrs['base_fields'] = _get_declared_fields(bases, attrs) - return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) + def __repr__(self): + return '<%s value=%s errors=%s>' % ( + self.__class__.__name__, self.value, self.errors + ) -class SerializerOptions(object): +class NestedBoundField(BoundField): """ - Meta class options for Serializer + This BoundField additionally implements __iter__ and __getitem__ + in order to support nested bound fields. This class is the type of + BoundField that is used for serializer fields. """ - def __init__(self, meta): - self.depth = getattr(meta, 'depth', 0) - self.fields = getattr(meta, 'fields', ()) - self.exclude = getattr(meta, 'exclude', ()) + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] + def __getitem__(self, key): + field = self.fields[key] + value = self.value.get(key) if self.value else None + error = self.errors.get(key) if self.errors else None + if isinstance(field, Serializer): + return NestedBoundField(field, value, error, prefix=self.name + '.') + return BoundField(field, value, error, prefix=self.name + '.') -class BaseSerializer(WritableField): - """ - This is the Serializer implementation. - We need to implement it as `BaseSerializer` due to metaclass magicks. + +class BindingDict(object): """ - class Meta(object): - pass + This dict-like object is used to store fields on a serializer. - _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata + This ensures that whenever fields are added to the serializer we call + `field.bind()` so that the `field_name` and `parent` attributes + can be set correctly. + """ + def __init__(self, serializer): + self.serializer = serializer + self.fields = SortedDict() - def __init__(self, instance=None, data=None, files=None, - context=None, partial=False, many=False, - allow_add_remove=False, **kwargs): - super(BaseSerializer, self).__init__(**kwargs) - self.opts = self._options_class(self.Meta) - self.parent = None - self.root = None - self.partial = partial - self.many = many - self.allow_add_remove = allow_add_remove + def __setitem__(self, key, field): + self.fields[key] = field + field.bind(field_name=key, parent=self.serializer) - self.context = context or {} + def __getitem__(self, key): + return self.fields[key] - self.init_data = data - self.init_files = files - self.object = instance + def __delitem__(self, key): + del self.fields[key] - self._data = None - self._files = None - self._errors = None + def items(self): + return self.fields.items() - if many and instance is not None and not hasattr(instance, '__iter__'): - raise ValueError('instance should be a queryset or other iterable with many=True') + def keys(self): + return self.fields.keys() - if allow_add_remove and not many: - raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') + def values(self): + return self.fields.values() - ##### - # Methods to determine which fields to use when (de)serializing objects. - @cached_property - def fields(self): - return self.get_fields() +class SerializerMetaclass(type): + """ + This metaclass sets a dictionary named `base_fields` on the class. - def get_default_fields(self): - """ - Return the complete set of default fields for the object, as a dict. - """ - return {} + Any instances of `Field` included as attributes on either the class + or on any of its superclasses will be include in the + `base_fields` dictionary. + """ - def get_fields(self): - """ - Returns the complete set of fields for the object as a dict. + @classmethod + def _get_declared_fields(cls, bases, attrs): + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field)] + fields.sort(key=lambda x: x[1]._creation_counter) - This will be the set of any explicitly declared fields, - plus the set of fields returned by get_default_fields(). - """ - ret = SortedDict() + # If this class is subclassing another Serializer, add that Serializer's + # fields. Note that we loop over the bases in *reverse*. This is necessary + # in order to maintain the correct order of fields. + for base in bases[::-1]: + if hasattr(base, '_declared_fields'): + fields = list(base._declared_fields.items()) + fields - # Get the explicitly declared fields - base_fields = copy.deepcopy(self.base_fields) - for key, field in base_fields.items(): - ret[key] = field - - # Add in the default fields - default_fields = self.get_default_fields() - for key, val in default_fields.items(): - if key not in ret: - ret[key] = val - - # If 'fields' is specified, use those fields, in that order. - if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' - new = SortedDict() - for key in self.opts.fields: - new[key] = ret[key] - ret = new - - # Remove anything in 'exclude' - if self.opts.exclude: - assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' - for key in self.opts.exclude: - ret.pop(key, None) - - for key, field in ret.items(): - field.initialize(parent=self, field_name=key) + return SortedDict(fields) - return ret + def __new__(cls, name, bases, attrs): + attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) + return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) - ##### - # Methods to convert or revert from objects <--> primitive representations. - def get_field_key(self, field_name): - """ - Return the key that should be used for a given field. - """ - return field_name +@six.add_metaclass(SerializerMetaclass) +class Serializer(BaseSerializer): + default_error_messages = { + 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') + } - def restore_fields(self, data, files): - """ - Core of deserialization, together with `restore_object`. - Converts a dictionary of data into a dictionary of deserialized fields. - """ - reverted_data = {} + @property + def fields(self): + if not hasattr(self, '_fields'): + self._fields = BindingDict(self) + for key, value in self.get_fields().items(): + self._fields[key] = value + return self._fields - if data is not None and not isinstance(data, dict): - self._errors['non_field_errors'] = ['Invalid data'] + def get_fields(self): + # Every new serializer is created with a clone of the field instances. + # This allows users to dynamically modify the fields on a serializer + # instance without affecting every other serializer class. + return copy.deepcopy(self._declared_fields) + + def get_validators(self): + return getattr(getattr(self, 'Meta', None), 'validators', []) + + def get_initial(self): + if self._initial_data is not None: + return ReturnDict([ + (field_name, field.get_value(self._initial_data)) + for field_name, field in self.fields.items() + if field.get_value(self._initial_data) is not empty + ], serializer=self) + + return ReturnDict([ + (field.field_name, field.get_initial()) + for field in self.fields.values() + if not field.write_only + ], serializer=self) + + def get_value(self, dictionary): + # We override the default field access in order to support + # nested HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_dict(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) + + def run_validation(self, data=empty): + """ + We override the default `run_validation`, because the validation + performed by validators and the `.validate()` method should + be coerced into an error dictionary with a 'non_fields_error' key. + """ + if data is empty: + if getattr(self.root, 'partial', False): + raise SkipField() + if self.required: + self.fail('required') + return self.get_default() + + if data is None: + if not self.allow_null: + self.fail('null') return None - for field_name, field in self.fields.items(): - field.initialize(parent=self, field_name=field_name) - try: - field.field_from_native(data, files, field_name, reverted_data) - except ValidationError as err: - self._errors[field_name] = list(err.messages) - - return reverted_data - - def perform_validation(self, attrs): - """ - Run `validate_<fieldname>()` and `validate()` methods on the serializer - """ - for field_name, field in self.fields.items(): - if field_name in self._errors: - continue - - source = field.source or field_name - if self.partial and source not in attrs: - continue - try: - validate_method = getattr(self, 'validate_%s' % field_name, None) - if validate_method: - attrs = validate_method(attrs, source) - except ValidationError as err: - self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - - # If there are already errors, we don't run .validate() because - # field-validation failed and thus `attrs` may not be complete. - # which in turn can cause inconsistent validation errors. - if not self._errors: - try: - attrs = self.validate(attrs) - except ValidationError as err: - if hasattr(err, 'message_dict'): - for field_name, error_messages in err.message_dict.items(): - self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) - elif hasattr(err, 'messages'): - self._errors['non_field_errors'] = err.messages + if not isinstance(data, dict): + message = self.error_messages['invalid'].format( + datatype=type(data).__name__ + ) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [message] + }) - return attrs + value = self.to_internal_value(data) + try: + self.run_validators(value) + value = self.validate(value) + assert value is not None, '.validate() should return the validated data' + except ValidationError as exc: + if isinstance(exc.detail, dict): + # .validate() errors may be a dict, in which case, use + # standard {key: list of values} style. + raise ValidationError(dict([ + (key, value if isinstance(value, list) else [value]) + for key, value in exc.detail.items() + ])) + elif isinstance(exc.detail, list): + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: exc.detail + }) + else: + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [exc.detail] + }) - def validate(self, attrs): - """ - Stub method, to be overridden in Serializer subclasses - """ - return attrs + return value - def restore_object(self, attrs, instance=None): + def to_internal_value(self, data): """ - Deserialize a dictionary of attributes into an object instance. - You should override this method to control how deserialized objects - are instantiated. + Dict of native values <- Dict of primitive datatypes. """ - if instance is not None: - instance.update(attrs) - return instance - return attrs + ret = {} + errors = ReturnDict(serializer=self) + fields = [ + field for field in self.fields.values() + if (not field.read_only) or (field.default is not empty) + ] - def to_native(self, obj): - """ - Serialize objects -> primitives. - """ - ret = self._dict_class() - ret.fields = self._dict_class() + for field in fields: + validate_method = getattr(self, 'validate_' + field.field_name, None) + primitive_value = field.get_value(data) + try: + validated_value = field.run_validation(primitive_value) + if validate_method is not None: + validated_value = validate_method(validated_value) + except ValidationError as exc: + errors[field.field_name] = exc.detail + except SkipField: + pass + else: + set_value(ret, field.source_attrs, validated_value) - for field_name, field in self.fields.items(): - if field.read_only and obj is None: - continue - field.initialize(parent=self, field_name=field_name) - key = self.get_field_key(field_name) - value = field.field_to_native(obj, field_name) - method = getattr(self, 'transform_%s' % field_name, None) - if callable(method): - value = method(obj, value) - if not getattr(field, 'write_only', False): - ret[key] = value - ret.fields[key] = self.augment_field(field, field_name, key, value) + if errors: + raise ValidationError(errors) return ret - def from_native(self, data, files=None): + def to_representation(self, instance): """ - Deserialize primitives -> objects. + Object instance -> Dict of primitive datatypes. """ - self._errors = {} - - if data is not None or files is not None: - attrs = self.restore_fields(data, files) - if attrs is not None: - attrs = self.perform_validation(attrs) - else: - self._errors['non_field_errors'] = ['No input provided'] + ret = ReturnDict(serializer=self) + fields = [field for field in self.fields.values() if not field.write_only] - if not self._errors: - return self.restore_object(attrs, instance=getattr(self, 'object', None)) - - def augment_field(self, field, field_name, key, value): - # This horrible stuff is to manage serializers rendering to HTML - field._errors = self._errors.get(key) if self._errors else None - field._name = field_name - field._value = self.init_data.get(key) if self._errors and self.init_data else value - if not field.label: - field.label = pretty_name(key) - return field - - def field_to_native(self, obj, field_name): - """ - Override default so that the serializer can be used as a nested field - across relationships. - """ - if self.write_only: - return None - - if self.source == '*': - return self.to_native(obj) - - # Get the raw field value - try: - source = self.source or field_name - value = obj - - for component in source.split('.'): - if value is None: - break - value = get_component(value, component) - except ObjectDoesNotExist: - return None - - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] - - if value is None: - return None - - if self.many: - return [self.to_native(item) for item in value] - return self.to_native(value) - - def field_from_native(self, data, files, field_name, into): - """ - Override default so that the serializer can be used as a writable - nested field across relationships. - """ - if self.read_only: - return - - try: - value = data[field_name] - except KeyError: - if self.default is not None and not self.partial: - # Note: partial updates shouldn't set defaults - value = copy.deepcopy(self.default) - else: - if self.required: - raise ValidationError(self.error_messages['required']) - return - - if self.source == '*': - if value: - reverted_data = self.restore_fields(value, {}) - if not self._errors: - into.update(reverted_data) - else: - if value in (None, ''): - into[(self.source or field_name)] = None + for field in fields: + attribute = field.get_attribute(instance) + if attribute is None: + value = None else: - # Set the serializer object if it exists - obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - - # If we have a model manager or similar object then we need - # to iterate through each instance. - if ( - self.many and - not hasattr(obj, '__iter__') and - is_simple_callable(getattr(obj, 'all', None)) - ): - obj = obj.all() - - kwargs = { - 'instance': obj, - 'data': value, - 'context': self.context, - 'partial': self.partial, - 'many': self.many, - 'allow_add_remove': self.allow_add_remove - } - serializer = self.__class__(**kwargs) - - if serializer.is_valid(): - into[self.source or field_name] = serializer.object - else: - # Propagate errors up to our parent - raise NestedValidationError(serializer.errors) + value = field.to_representation(attribute) + transform_method = getattr(self, 'transform_' + field.field_name, None) + if transform_method is not None: + value = transform_method(value) - def get_identity(self, data): - """ - This hook is required for bulk update. - It is used to determine the canonical identity of a given object. + ret[field.field_name] = value - Note that the data has not been validated at this point, so we need - to make sure that we catch any cases of incorrect datatypes being - passed to this method. - """ - try: - return data.get('id', None) - except AttributeError: - return None + return ret - @property - def errors(self): - """ - Run deserialization and return error data, - setting self.object if no errors occurred. - """ - if self._errors is None: - data, files = self.init_data, self.init_files + def validate(self, attrs): + return attrs - if self.many is not None: - many = self.many - else: - many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=3) - - if many: - ret = RelationsList() - errors = [] - update = self.object is not None - - if update: - # If this is a bulk update we need to map all the objects - # to a canonical identity so we can determine which - # individual object is being updated for each item in the - # incoming data - objects = self.object - identities = [self.get_identity(self.to_native(obj)) for obj in objects] - identity_to_objects = dict(zip(identities, objects)) - - if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): - for item in data: - if update: - # Determine which object we're updating - identity = self.get_identity(item) - self.object = identity_to_objects.pop(identity, None) - if self.object is None and not self.allow_add_remove: - ret.append(None) - errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) - continue - - ret.append(self.from_native(item, None)) - errors.append(self._errors) - - if update and self.allow_add_remove: - ret._deleted = identity_to_objects.values() - - self._errors = any(errors) and errors or [] - else: - self._errors = {'non_field_errors': ['Expected a list of items.']} - else: - ret = self.from_native(data, files) + def __repr__(self): + return representation.serializer_repr(self, indent=1) - if not self._errors: - self.object = ret + # The following are used for accessing `BoundField` instances on the + # serializer, for the purposes of presenting a form-like API onto the + # field values and field errors. - return self._errors + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] - def is_valid(self): - return not self.errors + def __getitem__(self, key): + field = self.fields[key] + value = self.data.get(key) + error = self.errors.get(key) if hasattr(self, '_errors') else None + if isinstance(field, Serializer): + return NestedBoundField(field, value, error) + return BoundField(field, value, error) - @property - def data(self): - """ - Returns the serialized data on the serializer. - """ - if self._data is None: - obj = self.object - if self.many is not None: - many = self.many - else: - many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=2) - - if many: - self._data = [self.to_native(item) for item in obj] - else: - self._data = self.to_native(obj) +# There's some replication of `ListField` here, +# but that's probably better than obfuscating the call hierarchy. - return self._data +class ListSerializer(BaseSerializer): + child = None + many = True - def save_object(self, obj, **kwargs): - obj.save(**kwargs) + def __init__(self, *args, **kwargs): + self.child = kwargs.pop('child', copy.deepcopy(self.child)) + assert self.child is not None, '`child` is a required argument.' + assert not inspect.isclass(self.child), '`child` has not been instantiated.' + super(ListSerializer, self).__init__(*args, **kwargs) + self.child.bind(field_name='', parent=self) - def delete_object(self, obj): - obj.delete() + def get_initial(self): + if self._initial_data is not None: + return self.to_representation(self._initial_data) + return ReturnList(serializer=self) - def save(self, **kwargs): + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_list(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) + + def to_internal_value(self, data): """ - Save the deserialized object and return it. + List of dicts of native values <- List of dicts of primitive datatypes. """ - # Clear cached _data, which may be invalidated by `save()` - self._data = None - - if isinstance(self.object, list): - [self.save_object(item, **kwargs) for item in self.object] + if html.is_html_input(data): + data = html.parse_html_list(data) + return [self.child.run_validation(item) for item in data] - if self.object._deleted: - [self.delete_object(item) for item in self.object._deleted] - else: - self.save_object(self.object, **kwargs) - - return self.object - - def metadata(self): + def to_representation(self, data): """ - Return a dictionary of metadata about the fields on the serializer. - Useful for things like responding to OPTIONS requests, or generating - API schemas for auto-documentation. + List of object instances -> List of dicts of primitive datatypes. """ - return SortedDict( - [ - (field_name, field.metadata()) - for field_name, field in six.iteritems(self.fields) - ] + iterable = data.all() if (hasattr(data, 'all')) else data + return ReturnList( + [self.child.to_representation(item) for item in iterable], + serializer=self ) + def create(self, attrs_list): + return [self.child.create(attrs) for attrs in attrs_list] -class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): - pass + def __repr__(self): + return representation.list_repr(self, indent=1) -class ModelSerializerOptions(SerializerOptions): - """ - Meta class options for ModelSerializer - """ - def __init__(self, meta): - super(ModelSerializerOptions, self).__init__(meta) - self.model = getattr(meta, 'model', None) - self.read_only_fields = getattr(meta, 'read_only_fields', ()) - self.write_only_fields = getattr(meta, 'write_only_fields', ()) - - -def _get_class_mapping(mapping, obj): - """ - Takes a dictionary with classes as keys, and an object. - Traverses the object's inheritance hierarchy in method - resolution order, and returns the first matching value - from the dictionary or None. - - """ - return next( - (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), - None - ) - +# ModelSerializer & HyperlinkedModelSerializer +# -------------------------------------------- class ModelSerializer(Serializer): - """ - A serializer that deals with model instances and querysets. - """ - _options_class = ModelSerializerOptions - - field_mapping = { + _field_mapping = ClassLookupDict({ models.AutoField: IntegerField, + models.BigIntegerField: IntegerField, + models.BooleanField: BooleanField, + models.CharField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.DateField: DateField, + models.DateTimeField: DateTimeField, + models.DecimalField: DecimalField, + models.EmailField: EmailField, + models.Field: ModelField, + models.FileField: FileField, models.FloatField: FloatField, + models.ImageField: ImageField, models.IntegerField: IntegerField, + models.NullBooleanField: NullBooleanField, models.PositiveIntegerField: IntegerField, - models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, - models.TimeField: TimeField, - models.DecimalField: DecimalField, - models.EmailField: EmailField, - models.CharField: CharField, - models.URLField: URLField, models.SlugField: SlugField, + models.SmallIntegerField: IntegerField, models.TextField: CharField, - models.CommaSeparatedIntegerField: CharField, - models.BooleanField: BooleanField, - models.NullBooleanField: BooleanField, - models.FileField: FileField, - models.ImageField: ImageField, - } - - def get_default_fields(self): - """ - Return all the fields that should be serialized for the model. - """ - - cls = self.opts.model - assert cls is not None, ( - "Serializer class '%s' is missing 'model' Meta option" % - self.__class__.__name__ + models.TimeField: TimeField, + models.URLField: URLField, + }) + _related_class = PrimaryKeyRelatedField + + def create(self, validated_attrs): + # Check that the user isn't trying to handle a writable nested field. + # If we don't do this explicitly they'd likely get a confusing + # error at the point of calling `Model.objects.create()`. + assert not any( + isinstance(field, BaseSerializer) and not field.read_only + for field in self.fields.values() + ), ( + 'The `.create()` method does not suport nested writable fields ' + 'by default. Write an explicit `.create()` method for serializer ' + '`%s.%s`, or set `read_only=True` on nested serializer fields.' % + (self.__class__.__module__, self.__class__.__name__) ) - opts = cls._meta.concrete_model._meta - ret = SortedDict() - nested = bool(self.opts.depth) - - # Deal with adding the primary key field - pk_field = opts.pk - while pk_field.rel and pk_field.rel.parent_link: - # If model is a child via multitable inheritance, use parent's pk - pk_field = pk_field.rel.to._meta.pk - - serializer_pk_field = self.get_pk_field(pk_field) - if serializer_pk_field: - ret[pk_field.name] = serializer_pk_field - - # Deal with forward relationships - forward_rels = [field for field in opts.fields if field.serialize] - forward_rels += [field for field in opts.many_to_many if field.serialize] - - for model_field in forward_rels: - has_through_model = False - - if model_field.rel: - to_many = isinstance(model_field, - models.fields.related.ManyToManyField) - related_model = _resolve_model(model_field.rel.to) - - if to_many and not model_field.rel.through._meta.auto_created: - has_through_model = True - - if model_field.rel and nested: - if len(inspect.getargspec(self.get_nested_field).args) == 2: - warnings.warn( - 'The `get_nested_field(model_field)` call signature ' - 'is deprecated. ' - 'Use `get_nested_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_nested_field(model_field) - else: - field = self.get_nested_field(model_field, related_model, to_many) - elif model_field.rel: - if len(inspect.getargspec(self.get_nested_field).args) == 3: - warnings.warn( - 'The `get_related_field(model_field, to_many)` call ' - 'signature is deprecated. ' - 'Use `get_related_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_related_field(model_field, to_many=to_many) - else: - field = self.get_related_field(model_field, related_model, to_many) - else: - field = self.get_field(model_field) - - if field: - if has_through_model: - field.read_only = True - - ret[model_field.name] = field - - # Deal with reverse relationships - if not self.opts.fields: - reverse_rels = [] - else: - # Reverse relationships are only included if they are explicitly - # present in the `fields` option on the serializer - reverse_rels = opts.get_all_related_objects() - reverse_rels += opts.get_all_related_many_to_many_objects() - - for relation in reverse_rels: - accessor_name = relation.get_accessor_name() - if not self.opts.fields or accessor_name not in self.opts.fields: - continue - related_model = relation.model - to_many = relation.field.rel.multiple - has_through_model = False - is_m2m = isinstance(relation.field, - models.fields.related.ManyToManyField) - - if ( - is_m2m and - hasattr(relation.field.rel, 'through') and - not relation.field.rel.through._meta.auto_created - ): - has_through_model = True - - if nested: - field = self.get_nested_field(None, related_model, to_many) - else: - field = self.get_related_field(None, related_model, to_many) - - if field: - if has_through_model: - field.read_only = True - - ret[accessor_name] = field - - # Ensure that 'read_only_fields' is an iterable - assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - - # Add the `read_only` flag to any fields that have been specified - # in the `read_only_fields` option - for field_name in self.opts.read_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`read_only_fields`, but also added " - "as an explicit field. Remove it from `read_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `read_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].read_only = True - - # Ensure that 'write_only_fields' is an iterable - assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' - - for field_name in self.opts.write_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`write_only_fields`, but also added " - "as an explicit field. Remove it from `write_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `write_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].write_only = True - - return ret - - def get_pk_field(self, model_field): - """ - Returns a default instance of the pk field. - """ - return self.get_field(model_field) - def get_nested_field(self, model_field, related_model, to_many): - """ - Creates a default instance of a nested relational field. - - Note that model_field will be `None` for reverse relationships. - """ - class NestedModelSerializer(ModelSerializer): - class Meta: - model = related_model - depth = self.opts.depth - 1 - - return NestedModelSerializer(many=to_many) + ModelClass = self.Meta.model - def get_related_field(self, model_field, related_model, to_many): - """ - Creates a default instance of a flat relational field. - - Note that model_field will be `None` for reverse relationships. - """ - # TODO: filter queryset using: - # .using(db).complex_filter(self.rel.limit_choices_to) - - kwargs = { - 'queryset': related_model._default_manager, - 'many': to_many - } - - if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) and model_field.editable - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - - if not model_field.editable: - kwargs['read_only'] = True + # Remove many-to-many relationships from validated_attrs. + # They are not valid arguments to the default `.create()` method, + # as they require that the instance has already been saved. + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in validated_attrs): + many_to_many[field_name] = validated_attrs.pop(field_name) - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name + instance = ModelClass.objects.create(**validated_attrs) - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # Save many-to-many relationships after the instance is created. + if many_to_many: + for field_name, value in many_to_many.items(): + setattr(instance, field_name, value) - return PrimaryKeyRelatedField(**kwargs) - - def get_field(self, model_field): - """ - Creates a default instance of a basic non-relational field. - """ - kwargs = {} - - if model_field.null or model_field.blank and model_field.editable: - kwargs['required'] = False - - if isinstance(model_field, models.AutoField) or not model_field.editable: - kwargs['read_only'] = True - - if model_field.has_default(): - kwargs['default'] = model_field.get_default() - - if issubclass(model_field.__class__, models.TextField): - kwargs['widget'] = widgets.Textarea - - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - - # TODO: TypedChoiceField? - if model_field.flatchoices: # This ModelField contains choices - kwargs['choices'] = model_field.flatchoices - if model_field.null: - kwargs['empty'] = None - return ChoiceField(**kwargs) - - # put this below the ChoiceField because min_value isn't a valid initializer - if issubclass(model_field.__class__, models.PositiveIntegerField) or\ - issubclass(model_field.__class__, models.PositiveSmallIntegerField): - kwargs['min_value'] = 0 - - if model_field.null and \ - issubclass(model_field.__class__, (models.CharField, models.TextField)): - kwargs['allow_none'] = True - - attribute_dict = { - models.CharField: ['max_length'], - models.CommaSeparatedIntegerField: ['max_length'], - models.DecimalField: ['max_digits', 'decimal_places'], - models.EmailField: ['max_length'], - models.FileField: ['max_length'], - models.ImageField: ['max_length'], - models.SlugField: ['max_length'], - models.URLField: ['max_length'], - } - - attributes = _get_class_mapping(attribute_dict, model_field) - if attributes: - for attribute in attributes: - kwargs.update({attribute: getattr(model_field, attribute)}) - - serializer_field_class = _get_class_mapping( - self.field_mapping, model_field) - - if serializer_field_class: - return serializer_field_class(**kwargs) - return ModelField(model_field=model_field, **kwargs) - - def get_validation_exclusions(self, instance=None): - """ - Return a list of field names to exclude from model validation. - """ - cls = self.opts.model - opts = cls._meta.concrete_model._meta - exclusions = [field.name for field in opts.fields + opts.many_to_many] - - for field_name, field in self.fields.items(): - field_name = field.source or field_name - if ( - field_name in exclusions - and not field.read_only - and (field.required or hasattr(instance, field_name)) - and not isinstance(field, Serializer) - ): - exclusions.remove(field_name) - return exclusions + return instance - def full_clean(self, instance): - """ - Perform Django's full_clean, and populate the `errors` dictionary - if any validation errors occur. + def update(self, instance, validated_attrs): + assert not any( + isinstance(field, BaseSerializer) and not field.read_only + for field in self.fields.values() + ), ( + 'The `.update()` method does not suport nested writable fields ' + 'by default. Write an explicit `.update()` method for serializer ' + '`%s.%s`, or set `read_only=True` on nested serializer fields.' % + (self.__class__.__module__, self.__class__.__name__) + ) - Note that we don't perform this inside the `.restore_object()` method, - so that subclasses can override `.restore_object()`, and still get - the full_clean validation checking. - """ - try: - instance.full_clean(exclude=self.get_validation_exclusions(instance)) - except ValidationError as err: - self._errors = err.message_dict - return None + for attr, value in validated_attrs.items(): + setattr(instance, attr, value) + instance.save() return instance - def restore_object(self, attrs, instance=None): - """ - Restore the model instance. - """ - m2m_data = {} - related_data = {} - nested_forward_relations = {} - meta = self.opts.model._meta - - # Reverse fk or one-to-one relations - for (obj, model) in meta.get_all_related_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - related_data[field_name] = attrs.pop(field_name) - - # Reverse m2m relations - for (obj, model) in meta.get_all_related_m2m_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - m2m_data[field_name] = attrs.pop(field_name) - - # Forward m2m relations - for field in meta.many_to_many + meta.virtual_fields: - if isinstance(field, GenericForeignKey): - continue - if field.name in attrs: - m2m_data[field.name] = attrs.pop(field.name) - - # Nested forward relations - These need to be marked so we can save - # them before saving the parent model instance. - for field_name in attrs.keys(): - if isinstance(self.fields.get(field_name, None), Serializer): - nested_forward_relations[field_name] = attrs[field_name] + def get_validators(self): + field_names = set([ + field.source for field in self.fields.values() + if (field.source != '*') and ('.' not in field.source) + ]) + + validators = getattr(getattr(self, 'Meta', None), 'validators', []) + model_class = self.Meta.model + + # Note that we make sure to check `unique_together` both on the + # base model class, but also on any parent classes. + for parent_class in [model_class] + list(model_class._meta.parents.keys()): + for unique_together in parent_class._meta.unique_together: + if field_names.issuperset(set(unique_together)): + validator = UniqueTogetherValidator( + queryset=parent_class._default_manager, + fields=unique_together + ) + validators.append(validator) + + # Add any unique_for_date/unique_for_month/unique_for_year constraints. + info = model_meta.get_field_info(model_class) + for field_name, field in info.fields_and_pk.items(): + if field.unique_for_date and field_name in field_names: + validator = UniqueForDateValidator( + queryset=model_class._default_manager, + field=field_name, + date_field=field.unique_for_date + ) + validators.append(validator) + + if field.unique_for_month and field_name in field_names: + validator = UniqueForMonthValidator( + queryset=model_class._default_manager, + field=field_name, + date_field=field.unique_for_month + ) + validators.append(validator) + + if field.unique_for_year and field_name in field_names: + validator = UniqueForYearValidator( + queryset=model_class._default_manager, + field=field_name, + date_field=field.unique_for_year + ) + validators.append(validator) + + return validators - # Create an empty instance of the model - if instance is None: - instance = self.opts.model() + def get_fields(self): + declared_fields = copy.deepcopy(self._declared_fields) - for key, val in attrs.items(): + ret = SortedDict() + model = getattr(self.Meta, 'model') + fields = getattr(self.Meta, 'fields', None) + exclude = getattr(self.Meta, 'exclude', None) + depth = getattr(self.Meta, 'depth', 0) + extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) + + assert not (fields and exclude), "Cannot set both 'fields' and 'exclude'." + + extra_kwargs = self._include_additional_options(extra_kwargs) + + # Retrieve metadata about fields & relationships on the model class. + info = model_meta.get_field_info(model) + + # Use the default set of field names if none is supplied explicitly. + if fields is None: + fields = self._get_default_field_names(declared_fields, info) + exclude = getattr(self.Meta, 'exclude', None) + if exclude is not None: + for field_name in exclude: + fields.remove(field_name) + + # Determine the set of model fields, and the fields that they map to. + # We actually only need this to deal with the slightly awkward case + # of supporting `unique_for_date`/`unique_for_month`/`unique_for_year`. + model_field_mapping = {} + for field_name in fields: + if field_name in declared_fields: + field = declared_fields[field_name] + source = field.source or field_name + else: + try: + source = extra_kwargs[field_name]['source'] + except KeyError: + source = field_name + # Model fields will always have a simple source mapping, + # they can't be nested attribute lookups. + if '.' not in source and source != '*': + model_field_mapping[source] = field_name + + # Determine if we need any additional `HiddenField` or extra keyword + # arguments to deal with `unique_for` dates that are required to + # be in the input data in order to validate it. + unique_fields = {} + for model_field_name, field_name in model_field_mapping.items(): try: - setattr(instance, key, val) - except ValueError: - self._errors[key] = [self.error_messages['required']] - - # Any relations that cannot be set until we've - # saved the model get hidden away on these - # private attributes, so we can deal with them - # at the point of save. - instance._related_data = related_data - instance._m2m_data = m2m_data - instance._nested_forward_relations = nested_forward_relations - - return instance - - def from_native(self, data, files): - """ - Override the default method to also include model field validation. - """ - instance = super(ModelSerializer, self).from_native(data, files) - if not self._errors: - return self.full_clean(instance) + model_field = model._meta.get_field(model_field_name) + except FieldDoesNotExist: + continue - def save_object(self, obj, **kwargs): - """ - Save the deserialized object. - """ - if getattr(obj, '_nested_forward_relations', None): - # Nested relationships need to be saved before we can save the - # parent instance. - for field_name, sub_object in obj._nested_forward_relations.items(): - if sub_object: - self.save_object(sub_object) - setattr(obj, field_name, sub_object) - - obj.save(**kwargs) - - if getattr(obj, '_m2m_data', None): - for accessor_name, object_list in obj._m2m_data.items(): - setattr(obj, accessor_name, object_list) - del(obj._m2m_data) - - if getattr(obj, '_related_data', None): - related_fields = dict([ - (field.get_accessor_name(), field) - for field, model - in obj._meta.get_all_related_objects_with_model() - ]) - for accessor_name, related in obj._related_data.items(): - if isinstance(related, RelationsList): - # Nested reverse fk relationship - for related_item in related: - fk_field = related_fields[accessor_name].field.name - setattr(related_item, fk_field, obj) - self.save_object(related_item) - - # Delete any removed objects - if related._deleted: - [self.delete_object(item) for item in related._deleted] - - elif isinstance(related, models.Model): - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) + # Deal with each of the `unique_for_*` cases. + for date_field_name in ( + model_field.unique_for_date, + model_field.unique_for_month, + model_field.unique_for_year + ): + if date_field_name is None: + continue + + # Get the model field that is refered too. + date_field = model._meta.get_field(date_field_name) + + if date_field.auto_now_add: + default = CreateOnlyDefault(timezone.now) + elif date_field.auto_now: + default = timezone.now + elif date_field.has_default(): + default = model_field.default else: - # Reverse FK or reverse one-one - setattr(obj, accessor_name, related) - del(obj._related_data) - - -class HyperlinkedModelSerializerOptions(ModelSerializerOptions): - """ - Options for HyperlinkedModelSerializer - """ - def __init__(self, meta): - super(HyperlinkedModelSerializerOptions, self).__init__(meta) - self.view_name = getattr(meta, 'view_name', None) - self.lookup_field = getattr(meta, 'lookup_field', None) - self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME) - + default = empty + + if date_field_name in model_field_mapping: + # The corresponding date field is present in the serializer + if date_field_name not in extra_kwargs: + extra_kwargs[date_field_name] = {} + if default is empty: + if 'required' not in extra_kwargs[date_field_name]: + extra_kwargs[date_field_name]['required'] = True + else: + if 'default' not in extra_kwargs[date_field_name]: + extra_kwargs[date_field_name]['default'] = default + else: + # The corresponding date field is not present in the, + # serializer. We have a default to use for the date, so + # add in a hidden field that populates it. + unique_fields[date_field_name] = HiddenField(default=default) + + # Now determine the fields that should be included on the serializer. + for field_name in fields: + if field_name in declared_fields: + # Field is explicitly declared on the class, use that. + ret[field_name] = declared_fields[field_name] + continue -class HyperlinkedModelSerializer(ModelSerializer): - """ - A subclass of ModelSerializer that uses hyperlinked relationships, - instead of primary key relationships. - """ - _options_class = HyperlinkedModelSerializerOptions - _default_view_name = '%(model_name)s-detail' - _hyperlink_field_class = HyperlinkedRelatedField - _hyperlink_identify_field_class = HyperlinkedIdentityField + elif field_name in info.fields_and_pk: + # Create regular model fields. + model_field = info.fields_and_pk[field_name] + field_cls = self._field_mapping[model_field] + kwargs = get_field_kwargs(field_name, model_field) + if 'choices' in kwargs: + # Fields with choices get coerced into `ChoiceField` + # instead of using their regular typed field. + field_cls = ChoiceField + if not issubclass(field_cls, ModelField): + # `model_field` is only valid for the fallback case of + # `ModelField`, which is used when no other typed field + # matched to the model field. + kwargs.pop('model_field', None) + if not issubclass(field_cls, CharField): + # `allow_blank` is only valid for textual fields. + kwargs.pop('allow_blank', None) + + elif field_name in info.relations: + # Create forward and reverse relationships. + relation_info = info.relations[field_name] + if depth: + field_cls = self._get_nested_class(depth, relation_info) + kwargs = get_nested_relation_kwargs(relation_info) + else: + field_cls = self._related_class + kwargs = get_relation_kwargs(field_name, relation_info) + # `view_name` is only valid for hyperlinked relationships. + if not issubclass(field_cls, HyperlinkedRelatedField): + kwargs.pop('view_name', None) + + elif hasattr(model, field_name): + # Create a read only field for model methods and properties. + field_cls = ReadOnlyField + kwargs = {} + + elif field_name == api_settings.URL_FIELD_NAME: + # Create the URL field. + field_cls = HyperlinkedIdentityField + kwargs = get_url_kwargs(model) - def get_default_fields(self): - fields = super(HyperlinkedModelSerializer, self).get_default_fields() + else: + raise ImproperlyConfigured( + 'Field name `%s` is not valid for model `%s`.' % + (field_name, model.__class__.__name__) + ) + + # Check that any fields declared on the class are + # also explicity included in `Meta.fields`. + missing_fields = set(declared_fields.keys()) - set(fields) + if missing_fields: + missing_field = list(missing_fields)[0] + raise ImproperlyConfigured( + 'Field `%s` has been declared on serializer `%s`, but ' + 'is missing from `Meta.fields`.' % + (missing_field, self.__class__.__name__) + ) + + # Populate any kwargs defined in `Meta.extra_kwargs` + extras = extra_kwargs.get(field_name, {}) + if extras.get('read_only', False): + for attr in [ + 'required', 'default', 'allow_blank', 'allow_null', + 'min_length', 'max_length', 'min_value', 'max_value', + 'validators', 'queryset' + ]: + kwargs.pop(attr, None) + kwargs.update(extras) + + # Create the serializer field. + ret[field_name] = field_cls(**kwargs) + + for field_name, field in unique_fields.items(): + ret[field_name] = field - if self.opts.view_name is None: - self.opts.view_name = self._get_default_view_name(self.opts.model) + return ret - if self.opts.url_field_name not in fields: - url_field = self._hyperlink_identify_field_class( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field + def _include_additional_options(self, extra_kwargs): + read_only_fields = getattr(self.Meta, 'read_only_fields', None) + if read_only_fields is not None: + for field_name in read_only_fields: + kwargs = extra_kwargs.get(field_name, {}) + kwargs['read_only'] = True + extra_kwargs[field_name] = kwargs + + # These are all pending deprecation. + write_only_fields = getattr(self.Meta, 'write_only_fields', None) + if write_only_fields is not None: + warnings.warn( + "The `Meta.write_only_fields` option is pending deprecation. " + "Use `Meta.extra_kwargs={<field_name>: {'write_only': True}}` instead.", + PendingDeprecationWarning, + stacklevel=3 + ) + for field_name in write_only_fields: + kwargs = extra_kwargs.get(field_name, {}) + kwargs['write_only'] = True + extra_kwargs[field_name] = kwargs + + view_name = getattr(self.Meta, 'view_name', None) + if view_name is not None: + warnings.warn( + "The `Meta.view_name` option is pending deprecation. " + "Use `Meta.extra_kwargs={'url': {'view_name': ...}}` instead.", + PendingDeprecationWarning, + stacklevel=3 ) - ret = self._dict_class() - ret[self.opts.url_field_name] = url_field - ret.update(fields) - fields = ret + kwargs = extra_kwargs.get(api_settings.URL_FIELD_NAME, {}) + kwargs['view_name'] = view_name + extra_kwargs[api_settings.URL_FIELD_NAME] = kwargs + + lookup_field = getattr(self.Meta, 'lookup_field', None) + if lookup_field is not None: + warnings.warn( + "The `Meta.lookup_field` option is pending deprecation. " + "Use `Meta.extra_kwargs={'url': {'lookup_field': ...}}` instead.", + PendingDeprecationWarning, + stacklevel=3 + ) + kwargs = extra_kwargs.get(api_settings.URL_FIELD_NAME, {}) + kwargs['lookup_field'] = lookup_field + extra_kwargs[api_settings.URL_FIELD_NAME] = kwargs + + return extra_kwargs + + def _get_default_field_names(self, declared_fields, model_info): + return ( + [model_info.pk.name] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) - return fields + def _get_nested_class(self, nested_depth, relation_info): + class NestedSerializer(ModelSerializer): + class Meta: + model = relation_info.related + depth = nested_depth + return NestedSerializer - def get_pk_field(self, model_field): - if self.opts.fields and model_field.name in self.opts.fields: - return self.get_field(model_field) - def get_related_field(self, model_field, related_model, to_many): - """ - Creates a default instance of a flat relational field. - """ - # TODO: filter queryset using: - # .using(db).complex_filter(self.rel.limit_choices_to) - kwargs = { - 'queryset': related_model._default_manager, - 'view_name': self._get_default_view_name(related_model), - 'many': to_many - } - - if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) and model_field.editable - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - - if self.opts.lookup_field: - kwargs['lookup_field'] = self.opts.lookup_field - - return self._hyperlink_field_class(**kwargs) - - def get_identity(self, data): - """ - This hook is required for bulk update. - We need to override the default, to use the url as the identity. - """ - try: - return data.get(self.opts.url_field_name, None) - except AttributeError: - return None +class HyperlinkedModelSerializer(ModelSerializer): + _related_class = HyperlinkedRelatedField + + def _get_default_field_names(self, declared_fields, model_info): + return ( + [api_settings.URL_FIELD_NAME] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) - def _get_default_view_name(self, model): - """ - Return the view name to use if 'view_name' is not specified in 'Meta' - """ - model_meta = model._meta - format_kwargs = { - 'app_label': model_meta.app_label, - 'model_name': model_meta.object_name.lower() - } - return self._default_view_name % format_kwargs + def _get_nested_class(self, nested_depth, relation_info): + class NestedSerializer(HyperlinkedModelSerializer): + class Meta: + model = relation_info.related + depth = nested_depth + return NestedSerializer |
