diff options
Diffstat (limited to 'rest_framework/serializers.py')
| -rw-r--r-- | rest_framework/serializers.py | 1917 |
1 files changed, 949 insertions, 968 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 7d85894f..6f89df0d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -11,20 +11,24 @@ python primitives. response content is handled by parsers and renderers. """ from __future__ import unicode_literals -import copy -import datetime -import inspect -import types -from decimal import Decimal -from django.contrib.contenttypes.generic import GenericForeignKey -from django.core.paginator import Page from django.db import models -from django.forms import widgets -from django.utils import six -from django.utils.datastructures import SortedDict -from django.utils.functional import cached_property -from django.core.exceptions import ObjectDoesNotExist -from rest_framework.settings import api_settings +from django.db.models.fields import FieldDoesNotExist +from django.utils.translation import ugettext_lazy as _ +from rest_framework.compat import unicode_to_repr +from rest_framework.utils import model_meta +from rest_framework.utils.field_mapping import ( + get_url_kwargs, get_field_kwargs, + get_relation_kwargs, get_nested_relation_kwargs, + ClassLookupDict +) +from rest_framework.utils.serializer_helpers import ( + ReturnDict, ReturnList, BoundField, NestedBoundField, BindingDict +) +from rest_framework.validators import ( + UniqueForDateValidator, UniqueForMonthValidator, UniqueForYearValidator, + UniqueTogetherValidator +) +import warnings # Note: We do the following so that users of the framework can use this style: @@ -38,1126 +42,1103 @@ from rest_framework.relations import * # NOQA from rest_framework.fields import * # NOQA -def _resolve_model(obj): - """ - Resolve supplied `obj` to a Django model class. +# We assume that 'validators' are intended for the child serializer, +# rather than the parent serializer. +LIST_SERIALIZER_KWARGS = ( + 'read_only', 'write_only', 'required', 'default', 'initial', 'source', + 'label', 'help_text', 'style', 'error_messages', + 'instance', 'data', 'partial', 'context' +) - `obj` must be a Django model class itself, or a string - representation of one. Useful in situtations like GH #1225 where - Django may not have resolved a string-based reference to a model in - another model's foreign key definition. - String representations should have the format: - 'appname.ModelName' - """ - if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: - app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) - elif inspect.isclass(obj) and issubclass(obj, models.Model): - return obj - else: - raise ValueError("{0} is not a Django model".format(obj)) +# BaseSerializer +# -------------- +class BaseSerializer(Field): + """ + The BaseSerializer class provides a minimal class which may be used + for writing custom serializer implementations. -def pretty_name(name): - """Converts 'first_name' to 'First name'""" - if not name: - return '' - return name.replace('_', ' ').capitalize() + Note that we strongly restrict the ordering of operations/properties + that may be used on the serializer in order to enforce correct usage. + In particular, if a `data=` argument is passed then: -class RelationsList(list): - _deleted = [] + .is_valid() - Available. + .initial_data - Available. + .validated_data - Only available after calling `is_valid()` + .errors - Only available after calling `is_valid()` + .data - Only available after calling `is_valid()` + If a `data=` argument is not passed then: -class NestedValidationError(ValidationError): + .is_valid() - Not available. + .initial_data - Not available. + .validated_data - Not available. + .errors - Not available. + .data - Available. """ - The default ValidationError behavior is to stringify each item in the list - if the messages are a list of error messages. - In the case of nested serializers, where the parent has many children, - then the child's `serializer.errors` will be a list of dicts. In the case - of a single child, the `serializer.errors` will be a dict. + def __init__(self, instance=None, data=empty, **kwargs): + self.instance = instance + if data is not empty: + self.initial_data = data + self.partial = kwargs.pop('partial', False) + self._context = kwargs.pop('context', {}) + kwargs.pop('many', None) + super(BaseSerializer, self).__init__(**kwargs) - We need to override the default behavior to get properly nested error dicts. - """ + def __new__(cls, *args, **kwargs): + # We override this method in order to automagically create + # `ListSerializer` classes instead when `many=True` is set. + if kwargs.pop('many', False): + return cls.many_init(*args, **kwargs) + return super(BaseSerializer, cls).__new__(cls, *args, **kwargs) - def __init__(self, message): - if isinstance(message, dict): - self._messages = [message] - else: - self._messages = message + @classmethod + def many_init(cls, *args, **kwargs): + """ + This method implements the creation of a `ListSerializer` parent + class when `many=True` is used. You can customize it if you need to + control which keyword arguments are passed to the parent, and + which are passed to the child. + + Note that we're over-cautious in passing most arguments to both parent + and child classes in order to try to cover the general case. If you're + overriding this method you'll probably want something much simpler, eg: + + @classmethod + def many_init(cls, *args, **kwargs): + kwargs['child'] = cls() + return CustomListSerializer(*args, **kwargs) + """ + child_serializer = cls(*args, **kwargs) + list_kwargs = {'child': child_serializer} + list_kwargs.update(dict([ + (key, value) for key, value in kwargs.items() + if key in LIST_SERIALIZER_KWARGS + ])) + meta = getattr(cls, 'Meta', None) + list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer) + return list_serializer_class(*args, **list_kwargs) - @property - def messages(self): - return self._messages + def to_internal_value(self, data): + raise NotImplementedError('`to_internal_value()` must be implemented.') + def to_representation(self, instance): + raise NotImplementedError('`to_representation()` must be implemented.') -class DictWithMetadata(dict): - """ - A dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overridden to remove the metadata from the dict, since it shouldn't be - pickled and may in some instances be unpickleable. - """ - return dict(self) + def update(self, instance, validated_data): + raise NotImplementedError('`update()` must be implemented.') + def create(self, validated_data): + raise NotImplementedError('`create()` must be implemented.') -class SortedDictWithMetadata(SortedDict): - """ - A sorted dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overriden to remove the metadata from the dict, since it shouldn't be - pickle and may in some instances be unpickleable. - """ - return SortedDict(self).__dict__ + def save(self, **kwargs): + assert not hasattr(self, 'save_object'), ( + 'Serializer `%s.%s` has old-style version 2 `.save_object()` ' + 'that is no longer compatible with REST framework 3. ' + 'Use the new-style `.create()` and `.update()` methods instead.' % + (self.__class__.__module__, self.__class__.__name__) + ) + assert hasattr(self, '_errors'), ( + 'You must call `.is_valid()` before calling `.save()`.' + ) -def _is_protected_type(obj): - """ - True if the object is a native datatype that does not need to - be serialized further. - """ - return isinstance(obj, ( - types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) - ) + assert not self.errors, ( + 'You cannot call `.save()` on a serializer with invalid data.' + ) + validated_data = dict( + list(self.validated_data.items()) + + list(kwargs.items()) + ) -def _get_declared_fields(bases, attrs): - """ - Create a list of serializer field instances from the passed in 'attrs', - plus any fields on the base classes (in 'bases'). + if self.instance is not None: + self.instance = self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` did not return an object instance.' + ) + else: + self.instance = self.create(validated_data) + assert self.instance is not None, ( + '`create()` did not return an object instance.' + ) - Note that all fields from the base classes are used. - """ - fields = [(field_name, attrs.pop(field_name)) - for field_name, obj in list(six.iteritems(attrs)) - if isinstance(obj, Field)] - fields.sort(key=lambda x: x[1].creation_counter) + return self.instance - # If this class is subclassing another Serializer, add that Serializer's - # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to maintain the correct order of fields. - for base in bases[::-1]: - if hasattr(base, 'base_fields'): - fields = list(base.base_fields.items()) + fields + def is_valid(self, raise_exception=False): + assert not hasattr(self, 'restore_object'), ( + 'Serializer `%s.%s` has old-style version 2 `.restore_object()` ' + 'that is no longer compatible with REST framework 3. ' + 'Use the new-style `.create()` and `.update()` methods instead.' % + (self.__class__.__module__, self.__class__.__name__) + ) - return SortedDict(fields) + assert hasattr(self, 'initial_data'), ( + 'Cannot call `.is_valid()` as no `data=` keyword argument was' + 'passed when instantiating the serializer instance.' + ) + if not hasattr(self, '_validated_data'): + try: + self._validated_data = self.run_validation(self.initial_data) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.detail + else: + self._errors = {} -class SerializerMetaclass(type): - def __new__(cls, name, bases, attrs): - attrs['base_fields'] = _get_declared_fields(bases, attrs) - return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) + if self._errors and raise_exception: + raise ValidationError(self._errors) + return not bool(self._errors) -class SerializerOptions(object): - """ - Meta class options for Serializer - """ - def __init__(self, meta): - self.depth = getattr(meta, 'depth', 0) - self.fields = getattr(meta, 'fields', ()) - self.exclude = getattr(meta, 'exclude', ()) + @property + def data(self): + if hasattr(self, 'initial_data') and not hasattr(self, '_validated_data'): + msg = ( + 'When a serializer is passed a `data` keyword argument you ' + 'must call `.is_valid()` before attempting to access the ' + 'serialized `.data` representation.\n' + 'You should either call `.is_valid()` first, ' + 'or access `.initial_data` instead.' + ) + raise AssertionError(msg) + + if not hasattr(self, '_data'): + if self.instance is not None and not getattr(self, '_errors', None): + self._data = self.to_representation(self.instance) + elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None): + self._data = self.to_representation(self.validated_data) + else: + self._data = self.get_initial() + return self._data + @property + def errors(self): + if not hasattr(self, '_errors'): + msg = 'You must call `.is_valid()` before accessing `.errors`.' + raise AssertionError(msg) + return self._errors -class BaseSerializer(WritableField): + @property + def validated_data(self): + if not hasattr(self, '_validated_data'): + msg = 'You must call `.is_valid()` before accessing `.validated_data`.' + raise AssertionError(msg) + return self._validated_data + + +# Serializer & ListSerializer classes +# ----------------------------------- + +class SerializerMetaclass(type): """ - This is the Serializer implementation. - We need to implement it as `BaseSerializer` due to metaclass magicks. + This metaclass sets a dictionary named `base_fields` on the class. + + Any instances of `Field` included as attributes on either the class + or on any of its superclasses will be include in the + `base_fields` dictionary. """ - class Meta(object): - pass - _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata + @classmethod + def _get_declared_fields(cls, bases, attrs): + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field)] + fields.sort(key=lambda x: x[1]._creation_counter) - def __init__(self, instance=None, data=None, files=None, - context=None, partial=False, many=False, - allow_add_remove=False, **kwargs): - super(BaseSerializer, self).__init__(**kwargs) - self.opts = self._options_class(self.Meta) - self.parent = None - self.root = None - self.partial = partial - self.many = many - self.allow_add_remove = allow_add_remove + # If this class is subclassing another Serializer, add that Serializer's + # fields. Note that we loop over the bases in *reverse*. This is necessary + # in order to maintain the correct order of fields. + for base in bases[::-1]: + if hasattr(base, '_declared_fields'): + fields = list(base._declared_fields.items()) + fields - self.context = context or {} + return OrderedDict(fields) - self.init_data = data - self.init_files = files - self.object = instance + def __new__(cls, name, bases, attrs): + attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) + return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) - self._data = None - self._files = None - self._errors = None - if many and instance is not None and not hasattr(instance, '__iter__'): - raise ValueError('instance should be a queryset or other iterable with many=True') +def get_validation_error_detail(exc): + assert isinstance(exc, (ValidationError, DjangoValidationError)) - if allow_add_remove and not many: - raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') + if isinstance(exc, DjangoValidationError): + # Normally you should raise `serializers.ValidationError` + # inside your codebase, but we handle Django's validation + # exception class as well for simpler compat. + # Eg. Calling Model.clean() explicitly inside Serializer.validate() + return { + api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + } + elif isinstance(exc.detail, dict): + # If errors may be a dict we use the standard {key: list of values}. + # Here we ensure that all the values are *lists* of errors. + return dict([ + (key, value if isinstance(value, list) else [value]) + for key, value in exc.detail.items() + ]) + elif isinstance(exc.detail, list): + # Errors raised as a list are non-field errors. + return { + api_settings.NON_FIELD_ERRORS_KEY: exc.detail + } + # Errors raised as a string are non-field errors. + return { + api_settings.NON_FIELD_ERRORS_KEY: [exc.detail] + } - ##### - # Methods to determine which fields to use when (de)serializing objects. - @cached_property - def fields(self): - return self.get_fields() +@six.add_metaclass(SerializerMetaclass) +class Serializer(BaseSerializer): + default_error_messages = { + 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') + } - def get_default_fields(self): + @property + def fields(self): """ - Return the complete set of default fields for the object, as a dict. + A dictionary of {field_name: field_instance}. """ - return {} + # `fields` is evaluated lazily. We do this to ensure that we don't + # have issues importing modules that use ModelSerializers as fields, + # even if Django's app-loading stage has not yet run. + if not hasattr(self, '_fields'): + self._fields = BindingDict(self) + for key, value in self.get_fields().items(): + self._fields[key] = value + return self._fields def get_fields(self): """ - Returns the complete set of fields for the object as a dict. + Returns a dictionary of {field_name: field_instance}. + """ + # Every new serializer is created with a clone of the field instances. + # This allows users to dynamically modify the fields on a serializer + # instance without affecting every other serializer class. + return copy.deepcopy(self._declared_fields) - This will be the set of any explicitly declared fields, - plus the set of fields returned by get_default_fields(). + def get_validators(self): + """ + Returns a list of validator callables. """ - ret = SortedDict() - - # Get the explicitly declared fields - base_fields = copy.deepcopy(self.base_fields) - for key, field in base_fields.items(): - ret[key] = field - - # Add in the default fields - default_fields = self.get_default_fields() - for key, val in default_fields.items(): - if key not in ret: - ret[key] = val - - # If 'fields' is specified, use those fields, in that order. - if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' - new = SortedDict() - for key in self.opts.fields: - new[key] = ret[key] - ret = new - - # Remove anything in 'exclude' - if self.opts.exclude: - assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' - for key in self.opts.exclude: - ret.pop(key, None) - - for key, field in ret.items(): - field.initialize(parent=self, field_name=key) + # Used by the lazily-evaluated `validators` property. + return getattr(getattr(self, 'Meta', None), 'validators', []) + + def get_initial(self): + if hasattr(self, 'initial_data'): + return OrderedDict([ + (field_name, field.get_value(self.initial_data)) + for field_name, field in self.fields.items() + if field.get_value(self.initial_data) is not empty + and not field.read_only + ]) - return ret + return OrderedDict([ + (field.field_name, field.get_initial()) + for field in self.fields.values() + if not field.read_only + ]) - ##### - # Methods to convert or revert from objects <--> primitive representations. + def get_value(self, dictionary): + # We override the default field access in order to support + # nested HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_dict(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) - def get_field_key(self, field_name): + def run_validation(self, data=empty): """ - Return the key that should be used for a given field. + We override the default `run_validation`, because the validation + performed by validators and the `.validate()` method should + be coerced into an error dictionary with a 'non_fields_error' key. """ - return field_name + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data - def restore_fields(self, data, files): - """ - Core of deserialization, together with `restore_object`. - Converts a dictionary of data into a dictionary of deserialized fields. - """ - reverted_data = {} + value = self.to_internal_value(data) + try: + self.run_validators(value) + value = self.validate(value) + assert value is not None, '.validate() should return the validated data' + except (ValidationError, DjangoValidationError) as exc: + raise ValidationError(detail=get_validation_error_detail(exc)) - if data is not None and not isinstance(data, dict): - self._errors['non_field_errors'] = ['Invalid data'] - return None + return value - for field_name, field in self.fields.items(): - field.initialize(parent=self, field_name=field_name) + def to_internal_value(self, data): + """ + Dict of native values <- Dict of primitive datatypes. + """ + if not isinstance(data, dict): + message = self.error_messages['invalid'].format( + datatype=type(data).__name__ + ) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [message] + }) + + ret = OrderedDict() + errors = OrderedDict() + fields = [ + field for field in self.fields.values() + if (not field.read_only) or (field.default is not empty) + ] + + for field in fields: + validate_method = getattr(self, 'validate_' + field.field_name, None) + primitive_value = field.get_value(data) try: - field.field_from_native(data, files, field_name, reverted_data) - except ValidationError as err: - self._errors[field_name] = list(err.messages) + validated_value = field.run_validation(primitive_value) + if validate_method is not None: + validated_value = validate_method(validated_value) + except ValidationError as exc: + errors[field.field_name] = exc.detail + except DjangoValidationError as exc: + errors[field.field_name] = list(exc.messages) + except SkipField: + pass + else: + set_value(ret, field.source_attrs, validated_value) + + if errors: + raise ValidationError(errors) - return reverted_data + return ret - def perform_validation(self, attrs): + def to_representation(self, instance): """ - Run `validate_<fieldname>()` and `validate()` methods on the serializer + Object instance -> Dict of primitive datatypes. """ - for field_name, field in self.fields.items(): - if field_name in self._errors: - continue + ret = OrderedDict() + fields = [field for field in self.fields.values() if not field.write_only] - source = field.source or field_name - if self.partial and source not in attrs: - continue - try: - validate_method = getattr(self, 'validate_%s' % field_name, None) - if validate_method: - attrs = validate_method(attrs, source) - except ValidationError as err: - self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - - # If there are already errors, we don't run .validate() because - # field-validation failed and thus `attrs` may not be complete. - # which in turn can cause inconsistent validation errors. - if not self._errors: - try: - attrs = self.validate(attrs) - except ValidationError as err: - if hasattr(err, 'message_dict'): - for field_name, error_messages in err.message_dict.items(): - self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) - elif hasattr(err, 'messages'): - self._errors['non_field_errors'] = err.messages + for field in fields: + attribute = field.get_attribute(instance) + if attribute is None: + ret[field.field_name] = None + else: + ret[field.field_name] = field.to_representation(attribute) - return attrs + return ret def validate(self, attrs): - """ - Stub method, to be overridden in Serializer subclasses - """ return attrs - def restore_object(self, attrs, instance=None): - """ - Deserialize a dictionary of attributes into an object instance. - You should override this method to control how deserialized objects - are instantiated. - """ - if instance is not None: - instance.update(attrs) - return instance - return attrs + def __repr__(self): + return unicode_to_repr(representation.serializer_repr(self, indent=1)) - def to_native(self, obj): - """ - Serialize objects -> primitives. - """ - ret = self._dict_class() - ret.fields = self._dict_class() + # The following are used for accessing `BoundField` instances on the + # serializer, for the purposes of presenting a form-like API onto the + # field values and field errors. - for field_name, field in self.fields.items(): - if field.read_only and obj is None: - continue - field.initialize(parent=self, field_name=field_name) - key = self.get_field_key(field_name) - value = field.field_to_native(obj, field_name) - method = getattr(self, 'transform_%s' % field_name, None) - if callable(method): - value = method(obj, value) - if not getattr(field, 'write_only', False): - ret[key] = value - ret.fields[key] = self.augment_field(field, field_name, key, value) + def __iter__(self): + for field in self.fields.values(): + yield self[field.field_name] - return ret + def __getitem__(self, key): + field = self.fields[key] + value = self.data.get(key) + error = self.errors.get(key) if hasattr(self, '_errors') else None + if isinstance(field, Serializer): + return NestedBoundField(field, value, error) + return BoundField(field, value, error) - def from_native(self, data, files=None): - """ - Deserialize primitives -> objects. - """ - self._errors = {} + # Include a backlink to the serializer class on return objects. + # Allows renderers such as HTMLFormRenderer to get the full field info. - if data is not None or files is not None: - attrs = self.restore_fields(data, files) - if attrs is not None: - attrs = self.perform_validation(attrs) - else: - self._errors['non_field_errors'] = ['No input provided'] - - if not self._errors: - return self.restore_object(attrs, instance=getattr(self, 'object', None)) - - def augment_field(self, field, field_name, key, value): - # This horrible stuff is to manage serializers rendering to HTML - field._errors = self._errors.get(key) if self._errors else None - field._name = field_name - field._value = self.init_data.get(key) if self._errors and self.init_data else value - if not field.label: - field.label = pretty_name(key) - return field + @property + def data(self): + ret = super(Serializer, self).data + return ReturnDict(ret, serializer=self) - def field_to_native(self, obj, field_name): - """ - Override default so that the serializer can be used as a nested field - across relationships. - """ - if self.write_only: - return None + @property + def errors(self): + ret = super(Serializer, self).errors + return ReturnDict(ret, serializer=self) - if self.source == '*': - return self.to_native(obj) - # Get the raw field value - try: - source = self.source or field_name - value = obj +# There's some replication of `ListField` here, +# but that's probably better than obfuscating the call hierarchy. - for component in source.split('.'): - if value is None: - break - value = get_component(value, component) - except ObjectDoesNotExist: - return None +class ListSerializer(BaseSerializer): + child = None + many = True - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] + default_error_messages = { + 'not_a_list': _('Expected a list of items but got type `{input_type}`.') + } - if value is None: - return None + def __init__(self, *args, **kwargs): + self.child = kwargs.pop('child', copy.deepcopy(self.child)) + assert self.child is not None, '`child` is a required argument.' + assert not inspect.isclass(self.child), '`child` has not been instantiated.' + super(ListSerializer, self).__init__(*args, **kwargs) + self.child.bind(field_name='', parent=self) - if self.many: - return [self.to_native(item) for item in value] - return self.to_native(value) + def get_initial(self): + if hasattr(self, 'initial_data'): + return self.to_representation(self.initial_data) + return [] - def field_from_native(self, data, files, field_name, into): + def get_value(self, dictionary): """ - Override default so that the serializer can be used as a writable - nested field across relationships. + Given the input dictionary, return the field value. """ - if self.read_only: - return - - try: - value = data[field_name] - except KeyError: - if self.default is not None and not self.partial: - # Note: partial updates shouldn't set defaults - value = copy.deepcopy(self.default) - else: - if self.required: - raise ValidationError(self.error_messages['required']) - return - - if self.source == '*': - if value: - reverted_data = self.restore_fields(value, {}) - if not self._errors: - into.update(reverted_data) - else: - if value in (None, ''): - into[(self.source or field_name)] = None - else: - # Set the serializer object if it exists - obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - - # If we have a model manager or similar object then we need - # to iterate through each instance. - if ( - self.many and - not hasattr(obj, '__iter__') and - is_simple_callable(getattr(obj, 'all', None)) - ): - obj = obj.all() - - kwargs = { - 'instance': obj, - 'data': value, - 'context': self.context, - 'partial': self.partial, - 'many': self.many, - 'allow_add_remove': self.allow_add_remove - } - serializer = self.__class__(**kwargs) - - if serializer.is_valid(): - into[self.source or field_name] = serializer.object - else: - # Propagate errors up to our parent - raise NestedValidationError(serializer.errors) + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_list(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) - def get_identity(self, data): + def run_validation(self, data=empty): """ - This hook is required for bulk update. - It is used to determine the canonical identity of a given object. - - Note that the data has not been validated at this point, so we need - to make sure that we catch any cases of incorrect datatypes being - passed to this method. + We override the default `run_validation`, because the validation + performed by validators and the `.validate()` method should + be coerced into an error dictionary with a 'non_fields_error' key. """ + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data + + value = self.to_internal_value(data) try: - return data.get('id', None) - except AttributeError: - return None + self.run_validators(value) + value = self.validate(value) + assert value is not None, '.validate() should return the validated data' + except (ValidationError, DjangoValidationError) as exc: + raise ValidationError(detail=get_validation_error_detail(exc)) - @property - def errors(self): + return value + + def to_internal_value(self, data): """ - Run deserialization and return error data, - setting self.object if no errors occurred. + List of dicts of native values <- List of dicts of primitive datatypes. """ - if self._errors is None: - data, files = self.init_data, self.init_files + if html.is_html_input(data): + data = html.parse_html_list(data) - if self.many is not None: - many = self.many - else: - many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=3) - - if many: - ret = RelationsList() - errors = [] - update = self.object is not None - - if update: - # If this is a bulk update we need to map all the objects - # to a canonical identity so we can determine which - # individual object is being updated for each item in the - # incoming data - objects = self.object - identities = [self.get_identity(self.to_native(obj)) for obj in objects] - identity_to_objects = dict(zip(identities, objects)) - - if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): - for item in data: - if update: - # Determine which object we're updating - identity = self.get_identity(item) - self.object = identity_to_objects.pop(identity, None) - if self.object is None and not self.allow_add_remove: - ret.append(None) - errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) - continue - - ret.append(self.from_native(item, None)) - errors.append(self._errors) - - if update and self.allow_add_remove: - ret._deleted = identity_to_objects.values() - - self._errors = any(errors) and errors or [] - else: - self._errors = {'non_field_errors': ['Expected a list of items.']} - else: - ret = self.from_native(data, files) + if not isinstance(data, list): + message = self.error_messages['not_a_list'].format( + input_type=type(data).__name__ + ) + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: [message] + }) - if not self._errors: - self.object = ret + ret = [] + errors = [] - return self._errors + for item in data: + try: + validated = self.child.run_validation(item) + except ValidationError as exc: + errors.append(exc.detail) + else: + ret.append(validated) + errors.append({}) - def is_valid(self): - return not self.errors + if any(errors): + raise ValidationError(errors) - @property - def data(self): + return ret + + def to_representation(self, data): """ - Returns the serialized data on the serializer. + List of object instances -> List of dicts of primitive datatypes. """ - if self._data is None: - obj = self.object + # Dealing with nested relationships, data can be a Manager, + # so, first get a queryset from the Manager if needed + iterable = data.all() if isinstance(data, models.Manager) else data + return [ + self.child.to_representation(item) for item in iterable + ] - if self.many is not None: - many = self.many - else: - many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=2) - - if many: - self._data = [self.to_native(item) for item in obj] - else: - self._data = self.to_native(obj) - - return self._data + def validate(self, attrs): + return attrs - def save_object(self, obj, **kwargs): - obj.save(**kwargs) + def update(self, instance, validated_data): + raise NotImplementedError( + "Serializers with many=True do not support multiple update by " + "default, only multiple create. For updates it is unclear how to " + "deal with insertions and deletions. If you need to support " + "multiple update, use a `ListSerializer` class and override " + "`.update()` so you can specify the behavior exactly." + ) - def delete_object(self, obj): - obj.delete() + def create(self, validated_data): + return [ + self.child.create(attrs) for attrs in validated_data + ] def save(self, **kwargs): """ - Save the deserialized object and return it. + Save and return a list of object instances. """ - # Clear cached _data, which may be invalidated by `save()` - self._data = None + validated_data = [ + dict(list(attrs.items()) + list(kwargs.items())) + for attrs in self.validated_data + ] + + if self.instance is not None: + self.instance = self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` did not return an object instance.' + ) + else: + self.instance = self.create(validated_data) + assert self.instance is not None, ( + '`create()` did not return an object instance.' + ) - if isinstance(self.object, list): - [self.save_object(item, **kwargs) for item in self.object] + return self.instance - if self.object._deleted: - [self.delete_object(item) for item in self.object._deleted] - else: - self.save_object(self.object, **kwargs) + def __repr__(self): + return unicode_to_repr(representation.list_repr(self, indent=1)) - return self.object + # Include a backlink to the serializer class on return objects. + # Allows renderers such as HTMLFormRenderer to get the full field info. - def metadata(self): - """ - Return a dictionary of metadata about the fields on the serializer. - Useful for things like responding to OPTIONS requests, or generating - API schemas for auto-documentation. - """ - return SortedDict( - [ - (field_name, field.metadata()) - for field_name, field in six.iteritems(self.fields) - ] - ) + @property + def data(self): + ret = super(ListSerializer, self).data + return ReturnList(ret, serializer=self) + @property + def errors(self): + ret = super(ListSerializer, self).errors + if isinstance(ret, dict): + return ReturnDict(ret, serializer=self) + return ReturnList(ret, serializer=self) -class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): - pass +# ModelSerializer & HyperlinkedModelSerializer +# -------------------------------------------- -class ModelSerializerOptions(SerializerOptions): - """ - Meta class options for ModelSerializer +def raise_errors_on_nested_writes(method_name, serializer, validated_data): """ - def __init__(self, meta): - super(ModelSerializerOptions, self).__init__(meta) - self.model = getattr(meta, 'model', None) - self.read_only_fields = getattr(meta, 'read_only_fields', ()) - self.write_only_fields = getattr(meta, 'write_only_fields', ()) + Give explicit errors when users attempt to pass writable nested data. + If we don't do this explicitly they'd get a less helpful error when + calling `.save()` on the serializer. -def _get_class_mapping(mapping, obj): - """ - Takes a dictionary with classes as keys, and an object. - Traverses the object's inheritance hierarchy in method - resolution order, and returns the first matching value - from the dictionary or None. + We don't *automatically* support these sorts of nested writes brecause + there are too many ambiguities to define a default behavior. + + Eg. Suppose we have a `UserSerializer` with a nested profile. How should + we handle the case of an update, where the `profile` realtionship does + not exist? Any of the following might be valid: + * Raise an application error. + * Silently ignore the nested part of the update. + * Automatically create a profile instance. """ - return next( - (mapping[cls] for cls in inspect.getmro(obj.__class__) if cls in mapping), - None + + # Ensure we don't have a writable nested field. For example: + # + # class UserSerializer(ModelSerializer): + # ... + # profile = ProfileSerializer() + assert not any( + isinstance(field, BaseSerializer) and (key in validated_data) + and isinstance(validated_data[key], (list, dict)) + for key, field in serializer.fields.items() + ), ( + 'The `.{method_name}()` method does not support writable nested' + 'fields by default.\nWrite an explicit `.{method_name}()` method for ' + 'serializer `{module}.{class_name}`, or set `read_only=True` on ' + 'nested serializer fields.'.format( + method_name=method_name, + module=serializer.__class__.__module__, + class_name=serializer.__class__.__name__ + ) + ) + + # Ensure we don't have a writable dotted-source field. For example: + # + # class UserSerializer(ModelSerializer): + # ... + # address = serializer.CharField('profile.address') + assert not any( + '.' in field.source and (key in validated_data) + and isinstance(validated_data[key], (list, dict)) + for key, field in serializer.fields.items() + ), ( + 'The `.{method_name}()` method does not support writable dotted-source ' + 'fields by default.\nWrite an explicit `.{method_name}()` method for ' + 'serializer `{module}.{class_name}`, or set `read_only=True` on ' + 'dotted-source serializer fields.'.format( + method_name=method_name, + module=serializer.__class__.__module__, + class_name=serializer.__class__.__name__ + ) ) class ModelSerializer(Serializer): """ - A serializer that deals with model instances and querysets. - """ - _options_class = ModelSerializerOptions + A `ModelSerializer` is just a regular `Serializer`, except that: + + * A set of default fields are automatically populated. + * A set of default validators are automatically populated. + * Default `.create()` and `.update()` implementations are provided. - field_mapping = { + The process of automatically determining a set of serializer fields + based on the model fields is reasonably complex, but you almost certainly + don't need to dig into the implementation. + + If the `ModelSerializer` class *doesn't* generate the set of fields that + you need you should either declare the extra/differing fields explicitly on + the serializer class, or simply use a `Serializer` class. + """ + _field_mapping = ClassLookupDict({ models.AutoField: IntegerField, + models.BigIntegerField: IntegerField, + models.BooleanField: BooleanField, + models.CharField: CharField, + models.CommaSeparatedIntegerField: CharField, + models.DateField: DateField, + models.DateTimeField: DateTimeField, + models.DecimalField: DecimalField, + models.EmailField: EmailField, + models.Field: ModelField, + models.FileField: FileField, models.FloatField: FloatField, + models.ImageField: ImageField, models.IntegerField: IntegerField, + models.NullBooleanField: NullBooleanField, models.PositiveIntegerField: IntegerField, - models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, - models.TimeField: TimeField, - models.DecimalField: DecimalField, - models.EmailField: EmailField, - models.CharField: CharField, - models.URLField: URLField, models.SlugField: SlugField, + models.SmallIntegerField: IntegerField, models.TextField: CharField, - models.CommaSeparatedIntegerField: CharField, - models.BooleanField: BooleanField, - models.NullBooleanField: BooleanField, - models.FileField: FileField, - models.ImageField: ImageField, - } - - def get_default_fields(self): - """ - Return all the fields that should be serialized for the model. - """ - - cls = self.opts.model - assert cls is not None, ( - "Serializer class '%s' is missing 'model' Meta option" % - self.__class__.__name__ - ) - opts = cls._meta.concrete_model._meta - ret = SortedDict() - nested = bool(self.opts.depth) - - # Deal with adding the primary key field - pk_field = opts.pk - while pk_field.rel and pk_field.rel.parent_link: - # If model is a child via multitable inheritance, use parent's pk - pk_field = pk_field.rel.to._meta.pk - - serializer_pk_field = self.get_pk_field(pk_field) - if serializer_pk_field: - ret[pk_field.name] = serializer_pk_field - - # Deal with forward relationships - forward_rels = [field for field in opts.fields if field.serialize] - forward_rels += [field for field in opts.many_to_many if field.serialize] - - for model_field in forward_rels: - has_through_model = False - - if model_field.rel: - to_many = isinstance(model_field, - models.fields.related.ManyToManyField) - related_model = _resolve_model(model_field.rel.to) - - if to_many and not model_field.rel.through._meta.auto_created: - has_through_model = True - - if model_field.rel and nested: - if len(inspect.getargspec(self.get_nested_field).args) == 2: - warnings.warn( - 'The `get_nested_field(model_field)` call signature ' - 'is deprecated. ' - 'Use `get_nested_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_nested_field(model_field) - else: - field = self.get_nested_field(model_field, related_model, to_many) - elif model_field.rel: - if len(inspect.getargspec(self.get_nested_field).args) == 3: - warnings.warn( - 'The `get_related_field(model_field, to_many)` call ' - 'signature is deprecated. ' - 'Use `get_related_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_related_field(model_field, to_many=to_many) - else: - field = self.get_related_field(model_field, related_model, to_many) - else: - field = self.get_field(model_field) - - if field: - if has_through_model: - field.read_only = True - - ret[model_field.name] = field - - # Deal with reverse relationships - if not self.opts.fields: - reverse_rels = [] - else: - # Reverse relationships are only included if they are explicitly - # present in the `fields` option on the serializer - reverse_rels = opts.get_all_related_objects() - reverse_rels += opts.get_all_related_many_to_many_objects() - - for relation in reverse_rels: - accessor_name = relation.get_accessor_name() - if not self.opts.fields or accessor_name not in self.opts.fields: - continue - related_model = relation.model - to_many = relation.field.rel.multiple - has_through_model = False - is_m2m = isinstance(relation.field, - models.fields.related.ManyToManyField) - - if ( - is_m2m and - hasattr(relation.field.rel, 'through') and - not relation.field.rel.through._meta.auto_created - ): - has_through_model = True - - if nested: - field = self.get_nested_field(None, related_model, to_many) - else: - field = self.get_related_field(None, related_model, to_many) - - if field: - if has_through_model: - field.read_only = True - - ret[accessor_name] = field - - # Ensure that 'read_only_fields' is an iterable - assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - - # Add the `read_only` flag to any fields that have been specified - # in the `read_only_fields` option - for field_name in self.opts.read_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`read_only_fields`, but also added " - "as an explicit field. Remove it from `read_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `read_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].read_only = True - - # Ensure that 'write_only_fields' is an iterable - assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' - - for field_name in self.opts.write_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`write_only_fields`, but also added " - "as an explicit field. Remove it from `write_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `write_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].write_only = True - - return ret + models.TimeField: TimeField, + models.URLField: URLField, + }) + _related_class = PrimaryKeyRelatedField - def get_pk_field(self, model_field): - """ - Returns a default instance of the pk field. + def create(self, validated_data): """ - return self.get_field(model_field) + We have a bit of extra checking around this in order to provide + descriptive messages when something goes wrong, but this method is + essentially just: - def get_nested_field(self, model_field, related_model, to_many): - """ - Creates a default instance of a nested relational field. + return ExampleModel.objects.create(**validated_data) - Note that model_field will be `None` for reverse relationships. - """ - class NestedModelSerializer(ModelSerializer): - class Meta: - model = related_model - depth = self.opts.depth - 1 + If there are many to many fields present on the instance then they + cannot be set until the model is instantiated, in which case the + implementation is like so: - return NestedModelSerializer(many=to_many) + example_relationship = validated_data.pop('example_relationship') + instance = ExampleModel.objects.create(**validated_data) + instance.example_relationship = example_relationship + return instance - def get_related_field(self, model_field, related_model, to_many): + The default implementation also does not handle nested relationships. + If you want to support writable nested relationships you'll need + to write an explicit `.create()` method. """ - Creates a default instance of a flat relational field. + raise_errors_on_nested_writes('create', self, validated_data) - Note that model_field will be `None` for reverse relationships. - """ - # TODO: filter queryset using: - # .using(db).complex_filter(self.rel.limit_choices_to) + ModelClass = self.Meta.model - kwargs = { - 'queryset': related_model._default_manager, - 'many': to_many - } + # Remove many-to-many relationships from validated_data. + # They are not valid arguments to the default `.create()` method, + # as they require that the instance has already been saved. + info = model_meta.get_field_info(ModelClass) + many_to_many = {} + for field_name, relation_info in info.relations.items(): + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = validated_data.pop(field_name) - if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) and model_field.editable - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - - if not model_field.editable: - kwargs['read_only'] = True - - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + try: + instance = ModelClass.objects.create(**validated_data) + except TypeError as exc: + msg = ( + 'Got a `TypeError` when calling `%s.objects.create()`. ' + 'This may be because you have a writable field on the ' + 'serializer class that is not a valid argument to ' + '`%s.objects.create()`. You may need to make the field ' + 'read-only, or override the %s.create() method to handle ' + 'this correctly.\nOriginal exception text was: %s.' % + ( + ModelClass.__name__, + ModelClass.__name__, + self.__class__.__name__, + exc + ) + ) + raise TypeError(msg) - return PrimaryKeyRelatedField(**kwargs) + # Save many-to-many relationships after the instance is created. + if many_to_many: + for field_name, value in many_to_many.items(): + setattr(instance, field_name, value) - def get_field(self, model_field): - """ - Creates a default instance of a basic non-relational field. - """ - kwargs = {} - - if model_field.null or model_field.blank and model_field.editable: - kwargs['required'] = False - - if isinstance(model_field, models.AutoField) or not model_field.editable: - kwargs['read_only'] = True - - if model_field.has_default(): - kwargs['default'] = model_field.get_default() - - if issubclass(model_field.__class__, models.TextField): - kwargs['widget'] = widgets.Textarea - - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - - # TODO: TypedChoiceField? - if model_field.flatchoices: # This ModelField contains choices - kwargs['choices'] = model_field.flatchoices - if model_field.null: - kwargs['empty'] = None - return ChoiceField(**kwargs) - - # put this below the ChoiceField because min_value isn't a valid initializer - if issubclass(model_field.__class__, models.PositiveIntegerField) or\ - issubclass(model_field.__class__, models.PositiveSmallIntegerField): - kwargs['min_value'] = 0 - - if model_field.null and \ - issubclass(model_field.__class__, (models.CharField, models.TextField)): - kwargs['allow_none'] = True - - attribute_dict = { - models.CharField: ['max_length'], - models.CommaSeparatedIntegerField: ['max_length'], - models.DecimalField: ['max_digits', 'decimal_places'], - models.EmailField: ['max_length'], - models.FileField: ['max_length'], - models.ImageField: ['max_length'], - models.SlugField: ['max_length'], - models.URLField: ['max_length'], - } + return instance - attributes = _get_class_mapping(attribute_dict, model_field) - if attributes: - for attribute in attributes: - kwargs.update({attribute: getattr(model_field, attribute)}) + def update(self, instance, validated_data): + raise_errors_on_nested_writes('update', self, validated_data) - serializer_field_class = _get_class_mapping( - self.field_mapping, model_field) + for attr, value in validated_data.items(): + setattr(instance, attr, value) + instance.save() - if serializer_field_class: - return serializer_field_class(**kwargs) - return ModelField(model_field=model_field, **kwargs) + return instance - def get_validation_exclusions(self, instance=None): - """ - Return a list of field names to exclude from model validation. - """ - cls = self.opts.model - opts = cls._meta.concrete_model._meta - exclusions = [field.name for field in opts.fields + opts.many_to_many] - - for field_name, field in self.fields.items(): - field_name = field.source or field_name - if ( - field_name in exclusions - and not field.read_only - and (field.required or hasattr(instance, field_name)) - and not isinstance(field, Serializer) - ): - exclusions.remove(field_name) - return exclusions + def get_validators(self): + # If the validators have been declared explicitly then use that. + validators = getattr(getattr(self, 'Meta', None), 'validators', None) + if validators is not None: + return validators + + # Determine the default set of validators. + validators = [] + model_class = self.Meta.model + field_names = set([ + field.source for field in self.fields.values() + if (field.source != '*') and ('.' not in field.source) + ]) + + # Note that we make sure to check `unique_together` both on the + # base model class, but also on any parent classes. + for parent_class in [model_class] + list(model_class._meta.parents.keys()): + for unique_together in parent_class._meta.unique_together: + if field_names.issuperset(set(unique_together)): + validator = UniqueTogetherValidator( + queryset=parent_class._default_manager, + fields=unique_together + ) + validators.append(validator) + + # Add any unique_for_date/unique_for_month/unique_for_year constraints. + info = model_meta.get_field_info(model_class) + for field_name, field in info.fields_and_pk.items(): + if field.unique_for_date and field_name in field_names: + validator = UniqueForDateValidator( + queryset=model_class._default_manager, + field=field_name, + date_field=field.unique_for_date + ) + validators.append(validator) + + if field.unique_for_month and field_name in field_names: + validator = UniqueForMonthValidator( + queryset=model_class._default_manager, + field=field_name, + date_field=field.unique_for_month + ) + validators.append(validator) + + if field.unique_for_year and field_name in field_names: + validator = UniqueForYearValidator( + queryset=model_class._default_manager, + field=field_name, + date_field=field.unique_for_year + ) + validators.append(validator) + + return validators - def full_clean(self, instance): - """ - Perform Django's full_clean, and populate the `errors` dictionary - if any validation errors occur. + def get_fields(self): + declared_fields = copy.deepcopy(self._declared_fields) + + ret = OrderedDict() + model = getattr(self.Meta, 'model') + fields = getattr(self.Meta, 'fields', None) + exclude = getattr(self.Meta, 'exclude', None) + depth = getattr(self.Meta, 'depth', 0) + extra_kwargs = getattr(self.Meta, 'extra_kwargs', {}) + + if fields and not isinstance(fields, (list, tuple)): + raise TypeError( + 'The `fields` option must be a list or tuple. Got %s.' % + type(fields).__name__ + ) - Note that we don't perform this inside the `.restore_object()` method, - so that subclasses can override `.restore_object()`, and still get - the full_clean validation checking. - """ - try: - instance.full_clean(exclude=self.get_validation_exclusions(instance)) - except ValidationError as err: - self._errors = err.message_dict - return None - return instance + if exclude and not isinstance(exclude, (list, tuple)): + raise TypeError( + 'The `exclude` option must be a list or tuple. Got %s.' % + type(exclude).__name__ + ) - def restore_object(self, attrs, instance=None): - """ - Restore the model instance. - """ - m2m_data = {} - related_data = {} - nested_forward_relations = {} - meta = self.opts.model._meta - - # Reverse fk or one-to-one relations - for (obj, model) in meta.get_all_related_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - related_data[field_name] = attrs.pop(field_name) - - # Reverse m2m relations - for (obj, model) in meta.get_all_related_m2m_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - m2m_data[field_name] = attrs.pop(field_name) - - # Forward m2m relations - for field in meta.many_to_many + meta.virtual_fields: - if isinstance(field, GenericForeignKey): - continue - if field.name in attrs: - m2m_data[field.name] = attrs.pop(field.name) + assert not (fields and exclude), "Cannot set both 'fields' and 'exclude'." - # Nested forward relations - These need to be marked so we can save - # them before saving the parent model instance. - for field_name in attrs.keys(): - if isinstance(self.fields.get(field_name, None), Serializer): - nested_forward_relations[field_name] = attrs[field_name] + extra_kwargs = self._include_additional_options(extra_kwargs) - # Create an empty instance of the model - if instance is None: - instance = self.opts.model() + # Retrieve metadata about fields & relationships on the model class. + info = model_meta.get_field_info(model) - for key, val in attrs.items(): + # Use the default set of field names if none is supplied explicitly. + if fields is None: + fields = self._get_default_field_names(declared_fields, info) + exclude = getattr(self.Meta, 'exclude', None) + if exclude is not None: + for field_name in exclude: + assert field_name in fields, ( + 'The field in the `exclude` option must be a model field. Got %s.' % + field_name + ) + fields.remove(field_name) + + # Determine the set of model fields, and the fields that they map to. + # We actually only need this to deal with the slightly awkward case + # of supporting `unique_for_date`/`unique_for_month`/`unique_for_year`. + model_field_mapping = {} + for field_name in fields: + if field_name in declared_fields: + field = declared_fields[field_name] + source = field.source or field_name + else: + try: + source = extra_kwargs[field_name]['source'] + except KeyError: + source = field_name + # Model fields will always have a simple source mapping, + # they can't be nested attribute lookups. + if '.' not in source and source != '*': + model_field_mapping[source] = field_name + + # Determine if we need any additional `HiddenField` or extra keyword + # arguments to deal with `unique_for` dates that are required to + # be in the input data in order to validate it. + hidden_fields = {} + unique_constraint_names = set() + + for model_field_name, field_name in model_field_mapping.items(): try: - setattr(instance, key, val) - except ValueError: - self._errors[key] = [self.error_messages['required']] - - # Any relations that cannot be set until we've - # saved the model get hidden away on these - # private attributes, so we can deal with them - # at the point of save. - instance._related_data = related_data - instance._m2m_data = m2m_data - instance._nested_forward_relations = nested_forward_relations - - return instance - - def from_native(self, data, files): - """ - Override the default method to also include model field validation. - """ - instance = super(ModelSerializer, self).from_native(data, files) - if not self._errors: - return self.full_clean(instance) + model_field = model._meta.get_field(model_field_name) + except FieldDoesNotExist: + continue - def save_object(self, obj, **kwargs): - """ - Save the deserialized object. - """ - if getattr(obj, '_nested_forward_relations', None): - # Nested relationships need to be saved before we can save the - # parent instance. - for field_name, sub_object in obj._nested_forward_relations.items(): - if sub_object: - self.save_object(sub_object) - setattr(obj, field_name, sub_object) - - obj.save(**kwargs) - - if getattr(obj, '_m2m_data', None): - for accessor_name, object_list in obj._m2m_data.items(): - setattr(obj, accessor_name, object_list) - del(obj._m2m_data) - - if getattr(obj, '_related_data', None): - related_fields = dict([ - (field.get_accessor_name(), field) - for field, model - in obj._meta.get_all_related_objects_with_model() + # Include each of the `unique_for_*` field names. + unique_constraint_names |= set([ + model_field.unique_for_date, + model_field.unique_for_month, + model_field.unique_for_year ]) - for accessor_name, related in obj._related_data.items(): - if isinstance(related, RelationsList): - # Nested reverse fk relationship - for related_item in related: - fk_field = related_fields[accessor_name].field.name - setattr(related_item, fk_field, obj) - self.save_object(related_item) - - # Delete any removed objects - if related._deleted: - [self.delete_object(item) for item in related._deleted] - - elif isinstance(related, models.Model): - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) - else: - # Reverse FK or reverse one-one - setattr(obj, accessor_name, related) - del(obj._related_data) - - -class HyperlinkedModelSerializerOptions(ModelSerializerOptions): - """ - Options for HyperlinkedModelSerializer - """ - def __init__(self, meta): - super(HyperlinkedModelSerializerOptions, self).__init__(meta) - self.view_name = getattr(meta, 'view_name', None) - self.lookup_field = getattr(meta, 'lookup_field', None) - self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME) + unique_constraint_names -= set([None]) + + # Include each of the `unique_together` field names, + # so long as all the field names are included on the serializer. + for parent_class in [model] + list(model._meta.parents.keys()): + for unique_together_list in parent_class._meta.unique_together: + if set(fields).issuperset(set(unique_together_list)): + unique_constraint_names |= set(unique_together_list) + + # Now we have all the field names that have uniqueness constraints + # applied, we can add the extra 'required=...' or 'default=...' + # arguments that are appropriate to these fields, or add a `HiddenField` for it. + for unique_constraint_name in unique_constraint_names: + # Get the model field that is referred too. + unique_constraint_field = model._meta.get_field(unique_constraint_name) + + if getattr(unique_constraint_field, 'auto_now_add', None): + default = CreateOnlyDefault(timezone.now) + elif getattr(unique_constraint_field, 'auto_now', None): + default = timezone.now + elif unique_constraint_field.has_default(): + default = unique_constraint_field.default + else: + default = empty + + if unique_constraint_name in model_field_mapping: + # The corresponding field is present in the serializer + if unique_constraint_name not in extra_kwargs: + extra_kwargs[unique_constraint_name] = {} + if default is empty: + if 'required' not in extra_kwargs[unique_constraint_name]: + extra_kwargs[unique_constraint_name]['required'] = True + else: + if 'default' not in extra_kwargs[unique_constraint_name]: + extra_kwargs[unique_constraint_name]['default'] = default + elif default is not empty: + # The corresponding field is not present in the, + # serializer. We have a default to use for it, so + # add in a hidden field that populates it. + hidden_fields[unique_constraint_name] = HiddenField(default=default) + + # Now determine the fields that should be included on the serializer. + for field_name in fields: + if field_name in declared_fields: + # Field is explicitly declared on the class, use that. + ret[field_name] = declared_fields[field_name] + continue -class HyperlinkedModelSerializer(ModelSerializer): - """ - A subclass of ModelSerializer that uses hyperlinked relationships, - instead of primary key relationships. - """ - _options_class = HyperlinkedModelSerializerOptions - _default_view_name = '%(model_name)s-detail' - _hyperlink_field_class = HyperlinkedRelatedField - _hyperlink_identify_field_class = HyperlinkedIdentityField + elif field_name in info.fields_and_pk: + # Create regular model fields. + model_field = info.fields_and_pk[field_name] + field_cls = self._field_mapping[model_field] + kwargs = get_field_kwargs(field_name, model_field) + if 'choices' in kwargs: + # Fields with choices get coerced into `ChoiceField` + # instead of using their regular typed field. + field_cls = ChoiceField + if not issubclass(field_cls, ModelField): + # `model_field` is only valid for the fallback case of + # `ModelField`, which is used when no other typed field + # matched to the model field. + kwargs.pop('model_field', None) + if not issubclass(field_cls, CharField) and not issubclass(field_cls, ChoiceField): + # `allow_blank` is only valid for textual fields. + kwargs.pop('allow_blank', None) + + elif field_name in info.relations: + # Create forward and reverse relationships. + relation_info = info.relations[field_name] + if depth: + field_cls = self._get_nested_class(depth, relation_info) + kwargs = get_nested_relation_kwargs(relation_info) + else: + field_cls = self._related_class + kwargs = get_relation_kwargs(field_name, relation_info) + # `view_name` is only valid for hyperlinked relationships. + if not issubclass(field_cls, HyperlinkedRelatedField): + kwargs.pop('view_name', None) + + elif hasattr(model, field_name): + # Create a read only field for model methods and properties. + field_cls = ReadOnlyField + kwargs = {} + + elif field_name == api_settings.URL_FIELD_NAME: + # Create the URL field. + field_cls = HyperlinkedIdentityField + kwargs = get_url_kwargs(model) - def get_default_fields(self): - fields = super(HyperlinkedModelSerializer, self).get_default_fields() + else: + raise ImproperlyConfigured( + 'Field name `%s` is not valid for model `%s`.' % + (field_name, model.__class__.__name__) + ) + + # Check that any fields declared on the class are + # also explicitly included in `Meta.fields`. + missing_fields = set(declared_fields.keys()) - set(fields) + if missing_fields: + missing_field = list(missing_fields)[0] + raise ImproperlyConfigured( + 'Field `%s` has been declared on serializer `%s`, but ' + 'is missing from `Meta.fields`.' % + (missing_field, self.__class__.__name__) + ) + + # Populate any kwargs defined in `Meta.extra_kwargs` + extras = extra_kwargs.get(field_name, {}) + if extras.get('read_only', False): + for attr in [ + 'required', 'default', 'allow_blank', 'allow_null', + 'min_length', 'max_length', 'min_value', 'max_value', + 'validators', 'queryset' + ]: + kwargs.pop(attr, None) + + if extras.get('default') and kwargs.get('required') is False: + kwargs.pop('required') + + kwargs.update(extras) + + # Create the serializer field. + ret[field_name] = field_cls(**kwargs) + + for field_name, field in hidden_fields.items(): + ret[field_name] = field - if self.opts.view_name is None: - self.opts.view_name = self._get_default_view_name(self.opts.model) + return ret - if self.opts.url_field_name not in fields: - url_field = self._hyperlink_identify_field_class( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field + def _include_additional_options(self, extra_kwargs): + read_only_fields = getattr(self.Meta, 'read_only_fields', None) + if read_only_fields is not None: + for field_name in read_only_fields: + kwargs = extra_kwargs.get(field_name, {}) + kwargs['read_only'] = True + extra_kwargs[field_name] = kwargs + + # These are all pending deprecation. + write_only_fields = getattr(self.Meta, 'write_only_fields', None) + if write_only_fields is not None: + warnings.warn( + "The `Meta.write_only_fields` option is pending deprecation. " + "Use `Meta.extra_kwargs={<field_name>: {'write_only': True}}` instead.", + PendingDeprecationWarning, + stacklevel=3 ) - ret = self._dict_class() - ret[self.opts.url_field_name] = url_field - ret.update(fields) - fields = ret - - return fields + for field_name in write_only_fields: + kwargs = extra_kwargs.get(field_name, {}) + kwargs['write_only'] = True + extra_kwargs[field_name] = kwargs + + view_name = getattr(self.Meta, 'view_name', None) + if view_name is not None: + warnings.warn( + "The `Meta.view_name` option is pending deprecation. " + "Use `Meta.extra_kwargs={'url': {'view_name': ...}}` instead.", + PendingDeprecationWarning, + stacklevel=3 + ) + kwargs = extra_kwargs.get(api_settings.URL_FIELD_NAME, {}) + kwargs['view_name'] = view_name + extra_kwargs[api_settings.URL_FIELD_NAME] = kwargs + + lookup_field = getattr(self.Meta, 'lookup_field', None) + if lookup_field is not None: + warnings.warn( + "The `Meta.lookup_field` option is pending deprecation. " + "Use `Meta.extra_kwargs={'url': {'lookup_field': ...}}` instead.", + PendingDeprecationWarning, + stacklevel=3 + ) + kwargs = extra_kwargs.get(api_settings.URL_FIELD_NAME, {}) + kwargs['lookup_field'] = lookup_field + extra_kwargs[api_settings.URL_FIELD_NAME] = kwargs + + return extra_kwargs + + def _get_default_field_names(self, declared_fields, model_info): + return ( + [model_info.pk.name] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) - def get_pk_field(self, model_field): - if self.opts.fields and model_field.name in self.opts.fields: - return self.get_field(model_field) + def _get_nested_class(self, nested_depth, relation_info): + class NestedSerializer(ModelSerializer): + class Meta: + model = relation_info.related + depth = nested_depth - 1 - def get_related_field(self, model_field, related_model, to_many): - """ - Creates a default instance of a flat relational field. - """ - # TODO: filter queryset using: - # .using(db).complex_filter(self.rel.limit_choices_to) - kwargs = { - 'queryset': related_model._default_manager, - 'view_name': self._get_default_view_name(related_model), - 'many': to_many - } + return NestedSerializer - if model_field: - kwargs['required'] = not(model_field.null or model_field.blank) and model_field.editable - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - if model_field.verbose_name is not None: - kwargs['label'] = model_field.verbose_name - if self.opts.lookup_field: - kwargs['lookup_field'] = self.opts.lookup_field +class HyperlinkedModelSerializer(ModelSerializer): + """ + A type of `ModelSerializer` that uses hyperlinked relationships instead + of primary key relationships. Specifically: - return self._hyperlink_field_class(**kwargs) + * A 'url' field is included instead of the 'id' field. + * Relationships to other instances are hyperlinks, instead of primary keys. + """ + _related_class = HyperlinkedRelatedField + + def _get_default_field_names(self, declared_fields, model_info): + return ( + [api_settings.URL_FIELD_NAME] + + list(declared_fields.keys()) + + list(model_info.fields.keys()) + + list(model_info.forward_relations.keys()) + ) - def get_identity(self, data): - """ - This hook is required for bulk update. - We need to override the default, to use the url as the identity. - """ - try: - return data.get(self.opts.url_field_name, None) - except AttributeError: - return None + def _get_nested_class(self, nested_depth, relation_info): + class NestedSerializer(HyperlinkedModelSerializer): + class Meta: + model = relation_info.related + depth = nested_depth - 1 - def _get_default_view_name(self, model): - """ - Return the view name to use if 'view_name' is not specified in 'Meta' - """ - model_meta = model._meta - format_kwargs = { - 'app_label': model_meta.app_label, - 'model_name': model_meta.object_name.lower() - } - return self._default_view_name % format_kwargs + return NestedSerializer |
