diff options
Diffstat (limited to 'rest_framework/serializers.py')
| -rw-r--r-- | rest_framework/serializers.py | 1096 |
1 files changed, 291 insertions, 805 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index be8ad3f2..d121812d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -10,21 +10,14 @@ python primitives. 2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ -from __future__ import unicode_literals -import copy -import datetime -import inspect -import types -from decimal import Decimal -from django.contrib.contenttypes.generic import GenericForeignKey -from django.core.paginator import Page from django.db import models -from django.forms import widgets from django.utils import six -from django.utils.datastructures import SortedDict -from django.core.exceptions import ObjectDoesNotExist +from collections import namedtuple, OrderedDict +from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError from rest_framework.settings import api_settings - +from rest_framework.utils import html +import copy +import inspect # Note: We do the following so that users of the framework can use this style: # @@ -37,635 +30,339 @@ from rest_framework.relations import * # NOQA from rest_framework.fields import * # NOQA -def _resolve_model(obj): - """ - Resolve supplied `obj` to a Django model class. +FieldResult = namedtuple('FieldResult', ['field', 'value', 'error']) - `obj` must be a Django model class itself, or a string - representation of one. Useful in situtations like GH #1225 where - Django may not have resolved a string-based reference to a model in - another model's foreign key definition. - - String representations should have the format: - 'appname.ModelName' - """ - if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: - app_name, model_name = obj.split('.') - return models.get_model(app_name, model_name) - elif inspect.isclass(obj) and issubclass(obj, models.Model): - return obj - else: - raise ValueError("{0} is not a Django model".format(obj)) - - -def pretty_name(name): - """Converts 'first_name' to 'First name'""" - if not name: - return '' - return name.replace('_', ' ').capitalize() +class BaseSerializer(Field): + def __init__(self, instance=None, data=None, **kwargs): + super(BaseSerializer, self).__init__(**kwargs) + self.instance = instance + self._initial_data = data -class RelationsList(list): - _deleted = [] + def to_native(self, data): + raise NotImplementedError() + def to_primative(self, instance): + raise NotImplementedError() -class NestedValidationError(ValidationError): - """ - The default ValidationError behavior is to stringify each item in the list - if the messages are a list of error messages. + def update(self, instance): + raise NotImplementedError() - In the case of nested serializers, where the parent has many children, - then the child's `serializer.errors` will be a list of dicts. In the case - of a single child, the `serializer.errors` will be a dict. + def create(self): + raise NotImplementedError() - We need to override the default behavior to get properly nested error dicts. - """ + def save(self, extras=None): + if extras is not None: + self._validated_data.update(extras) - def __init__(self, message): - if isinstance(message, dict): - self._messages = [message] + if self.instance is not None: + self.update(self.instance) else: - self._messages = message - - @property - def messages(self): - return self._messages + self.instance = self.create() + return self.instance -class DictWithMetadata(dict): - """ - A dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overridden to remove the metadata from the dict, since it shouldn't be - pickled and may in some instances be unpickleable. - """ - return dict(self) - + def is_valid(self): + try: + self._validated_data = self.to_native(self._initial_data) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.args[0] + return False + self._errors = {} + return True -class SortedDictWithMetadata(SortedDict): - """ - A sorted dict-like object, that can have additional properties attached. - """ - def __getstate__(self): - """ - Used by pickle (e.g., caching). - Overriden to remove the metadata from the dict, since it shouldn't be - pickle and may in some instances be unpickleable. - """ - return SortedDict(self).__dict__ + @property + def data(self): + if not hasattr(self, '_data'): + if self.instance is not None: + self._data = self.to_primative(self.instance) + elif self._initial_data is not None: + self._data = { + field_name: field.get_value(self._initial_data) + for field_name, field in self.fields.items() + } + else: + self._data = self.get_initial() + return self._data + @property + def errors(self): + if not hasattr(self, '_errors'): + msg = 'You must call `.is_valid()` before accessing `.errors`.' + raise AssertionError(msg) + return self._errors -def _is_protected_type(obj): - """ - True if the object is a native datatype that does not need to - be serialized further. - """ - return isinstance(obj, ( - types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) - ) + @property + def validated_data(self): + if not hasattr(self, '_validated_data'): + msg = 'You must call `.is_valid()` before accessing `.validated_data`.' + raise AssertionError(msg) + return self._validated_data -def _get_declared_fields(bases, attrs): +class SerializerMetaclass(type): """ - Create a list of serializer field instances from the passed in 'attrs', - plus any fields on the base classes (in 'bases'). + This metaclass sets a dictionary named `base_fields` on the class. - Note that all fields from the base classes are used. + Any fields included as attributes on either the class or it's superclasses + will be include in the `base_fields` dictionary. """ - fields = [(field_name, attrs.pop(field_name)) - for field_name, obj in list(six.iteritems(attrs)) - if isinstance(obj, Field)] - fields.sort(key=lambda x: x[1].creation_counter) - # If this class is subclassing another Serializer, add that Serializer's - # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to maintain the correct order of fields. - for base in bases[::-1]: - if hasattr(base, 'base_fields'): - fields = list(base.base_fields.items()) + fields + @classmethod + def _get_fields(cls, bases, attrs): + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field)] + fields.sort(key=lambda x: x[1]._creation_counter) - return SortedDict(fields) + # If this class is subclassing another Serializer, add that Serializer's + # fields. Note that we loop over the bases in *reverse*. This is necessary + # in order to maintain the correct order of fields. + for base in bases[::-1]: + if hasattr(base, 'base_fields'): + fields = list(base.base_fields.items()) + fields + return OrderedDict(fields) -class SerializerMetaclass(type): def __new__(cls, name, bases, attrs): - attrs['base_fields'] = _get_declared_fields(bases, attrs) + attrs['base_fields'] = cls._get_fields(bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) -class SerializerOptions(object): - """ - Meta class options for Serializer - """ - def __init__(self, meta): - self.depth = getattr(meta, 'depth', 0) - self.fields = getattr(meta, 'fields', ()) - self.exclude = getattr(meta, 'exclude', ()) +@six.add_metaclass(SerializerMetaclass) +class Serializer(BaseSerializer): + def __new__(cls, *args, **kwargs): + many = kwargs.pop('many', False) + if many: + class DynamicListSerializer(ListSerializer): + child = cls() + return DynamicListSerializer(*args, **kwargs) + return super(Serializer, cls).__new__(cls) -class BaseSerializer(WritableField): - """ - This is the Serializer implementation. - We need to implement it as `BaseSerializer` due to metaclass magicks. - """ - class Meta(object): - pass - - _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata - - def __init__(self, instance=None, data=None, files=None, - context=None, partial=False, many=False, - allow_add_remove=False, **kwargs): - super(BaseSerializer, self).__init__(**kwargs) - self.opts = self._options_class(self.Meta) - self.parent = None - self.root = None - self.partial = partial - self.many = many - self.allow_add_remove = allow_add_remove + def __init__(self, *args, **kwargs): + kwargs.pop('context', None) + kwargs.pop('partial', None) + kwargs.pop('many', False) - self.context = context or {} + super(Serializer, self).__init__(*args, **kwargs) - self.init_data = data - self.init_files = files - self.object = instance + # Every new serializer is created with a clone of the field instances. + # This allows users to dynamically modify the fields on a serializer + # instance without affecting every other serializer class. self.fields = self.get_fields() - self._data = None - self._files = None - self._errors = None - - if many and instance is not None and not hasattr(instance, '__iter__'): - raise ValueError('instance should be a queryset or other iterable with many=True') - - if allow_add_remove and not many: - raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') - - ##### - # Methods to determine which fields to use when (de)serializing objects. - - def get_default_fields(self): - """ - Return the complete set of default fields for the object, as a dict. - """ - return {} - - def get_fields(self): - """ - Returns the complete set of fields for the object as a dict. - - This will be the set of any explicitly declared fields, - plus the set of fields returned by get_default_fields(). - """ - ret = SortedDict() - - # Get the explicitly declared fields - base_fields = copy.deepcopy(self.base_fields) - for key, field in base_fields.items(): - ret[key] = field - - # Add in the default fields - default_fields = self.get_default_fields() - for key, val in default_fields.items(): - if key not in ret: - ret[key] = val - - # If 'fields' is specified, use those fields, in that order. - if self.opts.fields: - assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple' - new = SortedDict() - for key in self.opts.fields: - new[key] = ret[key] - ret = new - - # Remove anything in 'exclude' - if self.opts.exclude: - assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple' - for key in self.opts.exclude: - ret.pop(key, None) - - for key, field in ret.items(): - field.initialize(parent=self, field_name=key) - - return ret - - ##### - # Methods to convert or revert from objects <--> primitive representations. - - def get_field_key(self, field_name): - """ - Return the key that should be used for a given field. - """ - return field_name - - def restore_fields(self, data, files): - """ - Core of deserialization, together with `restore_object`. - Converts a dictionary of data into a dictionary of deserialized fields. - """ - reverted_data = {} - - if data is not None and not isinstance(data, dict): - self._errors['non_field_errors'] = ['Invalid data'] - return None - + # Setup all the child fields, to provide them with the current context. for field_name, field in self.fields.items(): - field.initialize(parent=self, field_name=field_name) - try: - field.field_from_native(data, files, field_name, reverted_data) - except ValidationError as err: - self._errors[field_name] = list(err.messages) + field.bind(field_name, self, self) - return reverted_data + def get_fields(self): + return copy.deepcopy(self.base_fields) - def perform_validation(self, attrs): - """ - Run `validate_<fieldname>()` and `validate()` methods on the serializer - """ + def bind(self, field_name, parent, root): + # If the serializer is used as a field then when it becomes bound + # it also needs to bind all its child fields. + super(Serializer, self).bind(field_name, parent, root) for field_name, field in self.fields.items(): - if field_name in self._errors: - continue + field.bind(field_name, self, root) - source = field.source or field_name - if self.partial and source not in attrs: - continue - try: - validate_method = getattr(self, 'validate_%s' % field_name, None) - if validate_method: - attrs = validate_method(attrs, source) - except ValidationError as err: - self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - - # If there are already errors, we don't run .validate() because - # field-validation failed and thus `attrs` may not be complete. - # which in turn can cause inconsistent validation errors. - if not self._errors: - try: - attrs = self.validate(attrs) - except ValidationError as err: - if hasattr(err, 'message_dict'): - for field_name, error_messages in err.message_dict.items(): - self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) - elif hasattr(err, 'messages'): - self._errors['non_field_errors'] = err.messages - - return attrs + def get_initial(self): + return { + field.field_name: field.get_initial() + for field in self.fields.values() + } - def validate(self, attrs): - """ - Stub method, to be overridden in Serializer subclasses - """ - return attrs + def get_value(self, dictionary): + # We override the default field access in order to support + # nested HTML forms. + if html.is_html_input(dictionary): + return html.parse_html_dict(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) - def restore_object(self, attrs, instance=None): + def to_native(self, data): """ - Deserialize a dictionary of attributes into an object instance. - You should override this method to control how deserialized objects - are instantiated. + Dict of native values <- Dict of primitive datatypes. """ - if instance is not None: - instance.update(attrs) - return instance - return attrs + ret = {} + errors = {} + fields = [field for field in self.fields.values() if not field.read_only] - def to_native(self, obj): - """ - Serialize objects -> primitives. - """ - ret = self._dict_class() - ret.fields = self._dict_class() + for field in fields: + primitive_value = field.get_value(data) + try: + validated_value = field.validate(primitive_value) + except ValidationError as exc: + errors[field.field_name] = str(exc) + except SkipField: + pass + else: + set_value(ret, field.source_attrs, validated_value) - for field_name, field in self.fields.items(): - if field.read_only and obj is None: - continue - field.initialize(parent=self, field_name=field_name) - key = self.get_field_key(field_name) - value = field.field_to_native(obj, field_name) - method = getattr(self, 'transform_%s' % field_name, None) - if callable(method): - value = method(obj, value) - if not getattr(field, 'write_only', False): - ret[key] = value - ret.fields[key] = self.augment_field(field, field_name, key, value) + if errors: + raise ValidationError(errors) return ret - def from_native(self, data, files=None): - """ - Deserialize primitives -> objects. - """ - self._errors = {} - - if data is not None or files is not None: - attrs = self.restore_fields(data, files) - if attrs is not None: - attrs = self.perform_validation(attrs) - else: - self._errors['non_field_errors'] = ['No input provided'] - - if not self._errors: - return self.restore_object(attrs, instance=getattr(self, 'object', None)) - - def augment_field(self, field, field_name, key, value): - # This horrible stuff is to manage serializers rendering to HTML - field._errors = self._errors.get(key) if self._errors else None - field._name = field_name - field._value = self.init_data.get(key) if self._errors and self.init_data else value - if not field.label: - field.label = pretty_name(key) - return field - - def field_to_native(self, obj, field_name): + def to_primative(self, instance): """ - Override default so that the serializer can be used as a nested field - across relationships. + Object instance -> Dict of primitive datatypes. """ - if self.write_only: - return None + ret = OrderedDict() + fields = [field for field in self.fields.values() if not field.write_only] - if self.source == '*': - return self.to_native(obj) + for field in fields: + native_value = field.get_attribute(instance) + ret[field.field_name] = field.to_primative(native_value) - # Get the raw field value - try: - source = self.source or field_name - value = obj - - for component in source.split('.'): - if value is None: - break - value = get_component(value, component) - except ObjectDoesNotExist: - return None + return ret - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] + def __iter__(self): + errors = self.errors if hasattr(self, '_errors') else {} + for field in self.fields.values(): + value = self.data.get(field.field_name) if self.data else None + error = errors.get(field.field_name) + yield FieldResult(field, value, error) - if value is None: - return None - if self.many: - return [self.to_native(item) for item in value] - return self.to_native(value) +class ListSerializer(BaseSerializer): + child = None + initial = [] - def field_from_native(self, data, files, field_name, into): - """ - Override default so that the serializer can be used as a writable - nested field across relationships. - """ - if self.read_only: - return + def __init__(self, *args, **kwargs): + self.child = kwargs.pop('child', copy.deepcopy(self.child)) + assert self.child is not None, '`child` is a required argument.' - try: - value = data[field_name] - except KeyError: - if self.default is not None and not self.partial: - # Note: partial updates shouldn't set defaults - value = copy.deepcopy(self.default) - else: - if self.required: - raise ValidationError(self.error_messages['required']) - return - - if self.source == '*': - if value: - reverted_data = self.restore_fields(value, {}) - if not self._errors: - into.update(reverted_data) - else: - if value in (None, ''): - into[(self.source or field_name)] = None - else: - # Set the serializer object if it exists - obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - - # If we have a model manager or similar object then we need - # to iterate through each instance. - if ( - self.many and - not hasattr(obj, '__iter__') and - is_simple_callable(getattr(obj, 'all', None)) - ): - obj = obj.all() - - kwargs = { - 'instance': obj, - 'data': value, - 'context': self.context, - 'partial': self.partial, - 'many': self.many, - 'allow_add_remove': self.allow_add_remove - } - serializer = self.__class__(**kwargs) + kwargs.pop('context', None) + kwargs.pop('partial', None) - if serializer.is_valid(): - into[self.source or field_name] = serializer.object - else: - # Propagate errors up to our parent - raise NestedValidationError(serializer.errors) + super(ListSerializer, self).__init__(*args, **kwargs) + self.child.bind('', self, self) - def get_identity(self, data): - """ - This hook is required for bulk update. - It is used to determine the canonical identity of a given object. + def bind(self, field_name, parent, root): + # If the list is used as a field then it needs to provide + # the current context to the child serializer. + super(ListSerializer, self).bind(field_name, parent, root) + self.child.bind(field_name, self, root) - Note that the data has not been validated at this point, so we need - to make sure that we catch any cases of incorrect datatypes being - passed to this method. - """ - try: - return data.get('id', None) - except AttributeError: - return None + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if is_html_input(dictionary): + return html.parse_html_list(dictionary, prefix=self.field_name) + return dictionary.get(self.field_name, empty) - @property - def errors(self): + def to_native(self, data): """ - Run deserialization and return error data, - setting self.object if no errors occurred. + List of dicts of native values <- List of dicts of primitive datatypes. """ - if self._errors is None: - data, files = self.init_data, self.init_files + if html.is_html_input(data): + data = html.parse_html_list(data) - if self.many is not None: - many = self.many - else: - many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=3) - - if many: - ret = RelationsList() - errors = [] - update = self.object is not None - - if update: - # If this is a bulk update we need to map all the objects - # to a canonical identity so we can determine which - # individual object is being updated for each item in the - # incoming data - objects = self.object - identities = [self.get_identity(self.to_native(obj)) for obj in objects] - identity_to_objects = dict(zip(identities, objects)) - - if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): - for item in data: - if update: - # Determine which object we're updating - identity = self.get_identity(item) - self.object = identity_to_objects.pop(identity, None) - if self.object is None and not self.allow_add_remove: - ret.append(None) - errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) - continue - - ret.append(self.from_native(item, None)) - errors.append(self._errors) - - if update and self.allow_add_remove: - ret._deleted = identity_to_objects.values() - - self._errors = any(errors) and errors or [] - else: - self._errors = {'non_field_errors': ['Expected a list of items.']} - else: - ret = self.from_native(data, files) - - if not self._errors: - self.object = ret - - return self._errors - - def is_valid(self): - return not self.errors + return [self.child.validate(item) for item in data] - @property - def data(self): + def to_primative(self, data): """ - Returns the serialized data on the serializer. + List of object instances -> List of dicts of primitive datatypes. """ - if self._data is None: - obj = self.object + return [self.child.to_primative(item) for item in data] - if self.many is not None: - many = self.many - else: - many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) - if many: - warnings.warn('Implicit list/queryset serialization is deprecated. ' - 'Use the `many=True` flag when instantiating the serializer.', - DeprecationWarning, stacklevel=2) - - if many: - self._data = [self.to_native(item) for item in obj] - else: - self._data = self.to_native(obj) + def create(self, attrs_list): + return [self.child.create(attrs) for attrs in attrs_list] - return self._data + def save(self): + if self.instance is not None: + self.update(self.instance, self.validated_data) + self.instance = self.create(self.validated_data) + return self.instance - def save_object(self, obj, **kwargs): - obj.save(**kwargs) - def delete_object(self, obj): - obj.delete() - - def save(self, **kwargs): - """ - Save the deserialized object and return it. - """ - # Clear cached _data, which may be invalidated by `save()` - self._data = None - - if isinstance(self.object, list): - [self.save_object(item, **kwargs) for item in self.object] - - if self.object._deleted: - [self.delete_object(item) for item in self.object._deleted] - else: - self.save_object(self.object, **kwargs) - - return self.object - - def metadata(self): - """ - Return a dictionary of metadata about the fields on the serializer. - Useful for things like responding to OPTIONS requests, or generating - API schemas for auto-documentation. - """ - return SortedDict( - [ - (field_name, field.metadata()) - for field_name, field in six.iteritems(self.fields) - ] - ) +def _resolve_model(obj): + """ + Resolve supplied `obj` to a Django model class. + `obj` must be a Django model class itself, or a string + representation of one. Useful in situtations like GH #1225 where + Django may not have resolved a string-based reference to a model in + another model's foreign key definition. -class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)): - pass + String representations should have the format: + 'appname.ModelName' + """ + if isinstance(obj, six.string_types) and len(obj.split('.')) == 2: + app_name, model_name = obj.split('.') + return models.get_model(app_name, model_name) + elif inspect.isclass(obj) and issubclass(obj, models.Model): + return obj + else: + raise ValueError("{0} is not a Django model".format(obj)) -class ModelSerializerOptions(SerializerOptions): +class ModelSerializerOptions(object): """ Meta class options for ModelSerializer """ def __init__(self, meta): - super(ModelSerializerOptions, self).__init__(meta) - self.model = getattr(meta, 'model', None) - self.read_only_fields = getattr(meta, 'read_only_fields', ()) - self.write_only_fields = getattr(meta, 'write_only_fields', ()) + self.model = getattr(meta, 'model') + self.fields = getattr(meta, 'fields', ()) + self.depth = getattr(meta, 'depth', 0) class ModelSerializer(Serializer): - """ - A serializer that deals with model instances and querysets. - """ - _options_class = ModelSerializerOptions - field_mapping = { models.AutoField: IntegerField, - models.FloatField: FloatField, + # models.FloatField: FloatField, models.IntegerField: IntegerField, models.PositiveIntegerField: IntegerField, models.SmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, - models.DateTimeField: DateTimeField, - models.DateField: DateField, - models.TimeField: TimeField, - models.DecimalField: DecimalField, - models.EmailField: EmailField, + # models.DateTimeField: DateTimeField, + # models.DateField: DateField, + # models.TimeField: TimeField, + # models.DecimalField: DecimalField, + # models.EmailField: EmailField, models.CharField: CharField, - models.URLField: URLField, - models.SlugField: SlugField, + # models.URLField: URLField, + # models.SlugField: SlugField, models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, models.NullBooleanField: BooleanField, - models.FileField: FileField, - models.ImageField: ImageField, + # models.FileField: FileField, + # models.ImageField: ImageField, } + _options_class = ModelSerializerOptions + + def __init__(self, *args, **kwargs): + self.opts = self._options_class(self.Meta) + super(ModelSerializer, self).__init__(*args, **kwargs) + + def get_fields(self): + # Get the explicitly declared fields. + fields = copy.deepcopy(self.base_fields) + + # Add in the default fields. + for key, val in self.get_default_fields().items(): + if key not in fields: + fields[key] = val + + # If `fields` is set on the `Meta` class, + # then use only those fields, and in that order. + if self.opts.fields: + fields = OrderedDict([ + (key, fields[key]) for key in self.opts.fields + ]) + + return fields + def get_default_fields(self): """ Return all the fields that should be serialized for the model. """ - cls = self.opts.model - assert cls is not None, ( - "Serializer class '%s' is missing 'model' Meta option" % - self.__class__.__name__ - ) opts = cls._meta.concrete_model._meta - ret = SortedDict() + ret = OrderedDict() nested = bool(self.opts.depth) # Deal with adding the primary key field @@ -694,29 +391,9 @@ class ModelSerializer(Serializer): has_through_model = True if model_field.rel and nested: - if len(inspect.getargspec(self.get_nested_field).args) == 2: - warnings.warn( - 'The `get_nested_field(model_field)` call signature ' - 'is deprecated. ' - 'Use `get_nested_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_nested_field(model_field) - else: - field = self.get_nested_field(model_field, related_model, to_many) + field = self.get_nested_field(model_field, related_model, to_many) elif model_field.rel: - if len(inspect.getargspec(self.get_nested_field).args) == 3: - warnings.warn( - 'The `get_related_field(model_field, to_many)` call ' - 'signature is deprecated. ' - 'Use `get_related_field(model_field, related_model, ' - 'to_many) instead', - DeprecationWarning - ) - field = self.get_related_field(model_field, to_many=to_many) - else: - field = self.get_related_field(model_field, related_model, to_many) + field = self.get_related_field(model_field, related_model, to_many) else: field = self.get_field(model_field) @@ -763,38 +440,6 @@ class ModelSerializer(Serializer): ret[accessor_name] = field - # Ensure that 'read_only_fields' is an iterable - assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple' - - # Add the `read_only` flag to any fields that have been specified - # in the `read_only_fields` option - for field_name in self.opts.read_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`read_only_fields`, but also added " - "as an explicit field. Remove it from `read_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `read_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].read_only = True - - # Ensure that 'write_only_fields' is an iterable - assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple' - - for field_name in self.opts.write_only_fields: - assert field_name not in self.base_fields.keys(), ( - "field '%s' on serializer '%s' specified in " - "`write_only_fields`, but also added " - "as an explicit field. Remove it from `write_only_fields`." % - (field_name, self.__class__.__name__)) - assert field_name in ret, ( - "Non-existant field '%s' specified in `write_only_fields` " - "on serializer '%s'." % - (field_name, self.__class__.__name__)) - ret[field_name].write_only = True - return ret def get_pk_field(self, model_field): @@ -825,28 +470,24 @@ class ModelSerializer(Serializer): # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - kwargs = { - 'queryset': related_model._default_manager, - 'many': to_many - } + kwargs = {} + # 'queryset': related_model._default_manager, + # 'many': to_many + # } if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if not model_field.editable: kwargs['read_only'] = True - if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text - - return PrimaryKeyRelatedField(**kwargs) + return IntegerField(**kwargs) + # TODO: return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): """ @@ -869,8 +510,8 @@ class ModelSerializer(Serializer): if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text # TODO: TypedChoiceField? if model_field.flatchoices: # This ModelField contains choices @@ -880,7 +521,7 @@ class ModelSerializer(Serializer): return ChoiceField(**kwargs) # put this below the ChoiceField because min_value isn't a valid initializer - if issubclass(model_field.__class__, models.PositiveIntegerField) or\ + if issubclass(model_field.__class__, models.PositiveIntegerField) or \ issubclass(model_field.__class__, models.PositiveSmallIntegerField): kwargs['min_value'] = 0 @@ -888,170 +529,27 @@ class ModelSerializer(Serializer): issubclass(model_field.__class__, (models.CharField, models.TextField)): kwargs['allow_none'] = True - attribute_dict = { - models.CharField: ['max_length'], - models.CommaSeparatedIntegerField: ['max_length'], - models.DecimalField: ['max_digits', 'decimal_places'], - models.EmailField: ['max_length'], - models.FileField: ['max_length'], - models.ImageField: ['max_length'], - models.SlugField: ['max_length'], - models.URLField: ['max_length'], - } - - if model_field.__class__ in attribute_dict: - attributes = attribute_dict[model_field.__class__] - for attribute in attributes: - kwargs.update({attribute: getattr(model_field, attribute)}) + # attribute_dict = { + # models.CharField: ['max_length'], + # models.CommaSeparatedIntegerField: ['max_length'], + # models.DecimalField: ['max_digits', 'decimal_places'], + # models.EmailField: ['max_length'], + # models.FileField: ['max_length'], + # models.ImageField: ['max_length'], + # models.SlugField: ['max_length'], + # models.URLField: ['max_length'], + # } + + # if model_field.__class__ in attribute_dict: + # attributes = attribute_dict[model_field.__class__] + # for attribute in attributes: + # kwargs.update({attribute: getattr(model_field, attribute)}) try: return self.field_mapping[model_field.__class__](**kwargs) except KeyError: - return ModelField(model_field=model_field, **kwargs) - - def get_validation_exclusions(self, instance=None): - """ - Return a list of field names to exclude from model validation. - """ - cls = self.opts.model - opts = cls._meta.concrete_model._meta - exclusions = [field.name for field in opts.fields + opts.many_to_many] - - for field_name, field in self.fields.items(): - field_name = field.source or field_name - if ( - field_name in exclusions - and not field.read_only - and (field.required or hasattr(instance, field_name)) - and not isinstance(field, Serializer) - ): - exclusions.remove(field_name) - return exclusions - - def full_clean(self, instance): - """ - Perform Django's full_clean, and populate the `errors` dictionary - if any validation errors occur. - - Note that we don't perform this inside the `.restore_object()` method, - so that subclasses can override `.restore_object()`, and still get - the full_clean validation checking. - """ - try: - instance.full_clean(exclude=self.get_validation_exclusions(instance)) - except ValidationError as err: - self._errors = err.message_dict - return None - return instance - - def restore_object(self, attrs, instance=None): - """ - Restore the model instance. - """ - m2m_data = {} - related_data = {} - nested_forward_relations = {} - meta = self.opts.model._meta - - # Reverse fk or one-to-one relations - for (obj, model) in meta.get_all_related_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - related_data[field_name] = attrs.pop(field_name) - - # Reverse m2m relations - for (obj, model) in meta.get_all_related_m2m_objects_with_model(): - field_name = obj.get_accessor_name() - if field_name in attrs: - m2m_data[field_name] = attrs.pop(field_name) - - # Forward m2m relations - for field in meta.many_to_many + meta.virtual_fields: - if isinstance(field, GenericForeignKey): - continue - if field.name in attrs: - m2m_data[field.name] = attrs.pop(field.name) - - # Nested forward relations - These need to be marked so we can save - # them before saving the parent model instance. - for field_name in attrs.keys(): - if isinstance(self.fields.get(field_name, None), Serializer): - nested_forward_relations[field_name] = attrs[field_name] - - # Create an empty instance of the model - if instance is None: - instance = self.opts.model() - - for key, val in attrs.items(): - try: - setattr(instance, key, val) - except ValueError: - self._errors[key] = [self.error_messages['required']] - - # Any relations that cannot be set until we've - # saved the model get hidden away on these - # private attributes, so we can deal with them - # at the point of save. - instance._related_data = related_data - instance._m2m_data = m2m_data - instance._nested_forward_relations = nested_forward_relations - - return instance - - def from_native(self, data, files): - """ - Override the default method to also include model field validation. - """ - instance = super(ModelSerializer, self).from_native(data, files) - if not self._errors: - return self.full_clean(instance) - - def save_object(self, obj, **kwargs): - """ - Save the deserialized object. - """ - if getattr(obj, '_nested_forward_relations', None): - # Nested relationships need to be saved before we can save the - # parent instance. - for field_name, sub_object in obj._nested_forward_relations.items(): - if sub_object: - self.save_object(sub_object) - setattr(obj, field_name, sub_object) - - obj.save(**kwargs) - - if getattr(obj, '_m2m_data', None): - for accessor_name, object_list in obj._m2m_data.items(): - setattr(obj, accessor_name, object_list) - del(obj._m2m_data) - - if getattr(obj, '_related_data', None): - related_fields = dict([ - (field.get_accessor_name(), field) - for field, model - in obj._meta.get_all_related_objects_with_model() - ]) - for accessor_name, related in obj._related_data.items(): - if isinstance(related, RelationsList): - # Nested reverse fk relationship - for related_item in related: - fk_field = related_fields[accessor_name].field.name - setattr(related_item, fk_field, obj) - self.save_object(related_item) - - # Delete any removed objects - if related._deleted: - [self.delete_object(item) for item in related._deleted] - - elif isinstance(related, models.Model): - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) - else: - # Reverse FK or reverse one-one - setattr(obj, accessor_name, related) - del(obj._related_data) + # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)` + return CharField(**kwargs) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): @@ -1066,14 +564,10 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializer(ModelSerializer): - """ - A subclass of ModelSerializer that uses hyperlinked relationships, - instead of primary key relationships. - """ _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' - _hyperlink_field_class = HyperlinkedRelatedField - _hyperlink_identify_field_class = HyperlinkedIdentityField + #_hyperlink_field_class = HyperlinkedRelatedField + #_hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -1081,15 +575,15 @@ class HyperlinkedModelSerializer(ModelSerializer): if self.opts.view_name is None: self.opts.view_name = self._get_default_view_name(self.opts.model) - if self.opts.url_field_name not in fields: - url_field = self._hyperlink_identify_field_class( - view_name=self.opts.view_name, - lookup_field=self.opts.lookup_field - ) - ret = self._dict_class() - ret[self.opts.url_field_name] = url_field - ret.update(fields) - fields = ret + # if self.opts.url_field_name not in fields: + # url_field = self._hyperlink_identify_field_class( + # view_name=self.opts.view_name, + # lookup_field=self.opts.lookup_field + # ) + # ret = self._dict_class() + # ret[self.opts.url_field_name] = url_field + # ret.update(fields) + # fields = ret return fields @@ -1103,33 +597,25 @@ class HyperlinkedModelSerializer(ModelSerializer): """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - kwargs = { - 'queryset': related_model._default_manager, - 'view_name': self._get_default_view_name(related_model), - 'many': to_many - } + # kwargs = { + # 'queryset': related_model._default_manager, + # 'view_name': self._get_default_view_name(related_model), + # 'many': to_many + # } + kwargs = {} if model_field: kwargs['required'] = not(model_field.null or model_field.blank) - if model_field.help_text is not None: - kwargs['help_text'] = model_field.help_text + # if model_field.help_text is not None: + # kwargs['help_text'] = model_field.help_text if model_field.verbose_name is not None: kwargs['label'] = model_field.verbose_name - if self.opts.lookup_field: - kwargs['lookup_field'] = self.opts.lookup_field - - return self._hyperlink_field_class(**kwargs) + return IntegerField(**kwargs) + # if self.opts.lookup_field: + # kwargs['lookup_field'] = self.opts.lookup_field - def get_identity(self, data): - """ - This hook is required for bulk update. - We need to override the default, to use the url as the identity. - """ - try: - return data.get(self.opts.url_field_name, None) - except AttributeError: - return None + # return self._hyperlink_field_class(**kwargs) def _get_default_view_name(self, model): """ |
