diff options
Diffstat (limited to 'rest_framework/serializers.py')
-rw-r--r-- | rest_framework/serializers.py | 282 |
1 files changed, 192 insertions, 90 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8ee9a0ec..27458f96 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -3,8 +3,18 @@ import datetime import types from decimal import Decimal from django.db import models +from django.forms import widgets from django.utils.datastructures import SortedDict from rest_framework.compat import get_concrete_model + +# Note: We do the following so that users of the framework can use this style: +# +# example_field = serializers.CharField(...) +# +# This helps keep the seperation between model fields, form fields, and +# serializer fields more explicit. + +from rest_framework.relations import * from rest_framework.fields import * @@ -12,7 +22,16 @@ class DictWithMetadata(dict): """ A dict-like object, that can have additional properties attached. """ - pass + def __getstate__(self): + """ + Used by pickle (e.g., caching). + Overriden to remove metadata from the dict, since it shouldn't be pickled + and may in some instances be unpickleable. + """ + # return an instance of the first dict in MRO that isn't a DictWithMetadata + for base in self.__class__.__mro__: + if not isinstance(base, DictWithMetadata) and isinstance(base, dict): + return base(self) class SortedDictWithMetadata(SortedDict, DictWithMetadata): @@ -22,10 +41,6 @@ class SortedDictWithMetadata(SortedDict, DictWithMetadata): pass -class RecursionOccured(BaseException): - pass - - def _is_protected_type(obj): """ True if the object is a native datatype that does not need to @@ -33,10 +48,10 @@ def _is_protected_type(obj): """ return isinstance(obj, ( types.NoneType, - int, long, - datetime.datetime, datetime.date, datetime.time, - float, Decimal, - basestring) + int, long, + datetime.datetime, datetime.date, datetime.time, + float, Decimal, + basestring) ) @@ -54,7 +69,7 @@ def _get_declared_fields(bases, attrs): # If this class is subclassing another Serializer, add that Serializer's # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to the correct order of fields. + # in order to maintain the correct order of fields. for base in bases[::-1]: if hasattr(base, 'base_fields'): fields = base.base_fields.items() + fields @@ -73,7 +88,7 @@ class SerializerOptions(object): Meta class options for Serializer """ def __init__(self, meta): - self.nested = getattr(meta, 'nested', False) + self.depth = getattr(meta, 'depth', 0) self.fields = getattr(meta, 'fields', ()) self.exclude = getattr(meta, 'exclude', ()) @@ -83,51 +98,53 @@ class BaseSerializer(Field): pass _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations. + _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations. - def __init__(self, data=None, instance=None, context=None, **kwargs): + def __init__(self, instance=None, data=None, files=None, + context=None, partial=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) - self.fields = copy.deepcopy(self.base_fields) self.opts = self._options_class(self.Meta) self.parent = None self.root = None + self.partial = partial - self.stack = [] self.context = context or {} self.init_data = data + self.init_files = files self.object = instance + self.fields = self.get_fields() self._data = None + self._files = None self._errors = None ##### # Methods to determine which fields to use when (de)serializing objects. - def default_fields(self, serialize, obj=None, data=None, nested=False): + def get_default_fields(self): """ Return the complete set of default fields for the object, as a dict. """ return {} - def get_fields(self, serialize, obj=None, data=None, nested=False): + def get_fields(self): """ Returns the complete set of fields for the object as a dict. This will be the set of any explicitly declared fields, - plus the set of fields returned by default_fields(). + plus the set of fields returned by get_default_fields(). """ ret = SortedDict() # Get the explicitly declared fields - for key, field in self.fields.items(): + base_fields = copy.deepcopy(self.base_fields) + for key, field in base_fields.items(): ret[key] = field - # Set up the field - field.initialize(parent=self) # Add in the default fields - fields = self.default_fields(serialize, obj, data, nested) - for key, val in fields.items(): + default_fields = self.get_default_fields() + for key, val in default_fields.items(): if key not in ret: ret[key] = val @@ -143,25 +160,25 @@ class BaseSerializer(Field): for key in self.opts.exclude: ret.pop(key, None) + for key, field in ret.items(): + field.initialize(parent=self, field_name=key) + return ret ##### # Field methods - used when the serializer class is itself used as a field. - def initialize(self, parent): + def initialize(self, parent, field_name): """ Same behaviour as usual Field, except that we need to keep track - of state so that we can deal with handling maximum depth and recursion. + of state so that we can deal with handling maximum depth. """ - super(BaseSerializer, self).initialize(parent) - self.stack = parent.stack[:] - if parent.opts.nested and not isinstance(parent.opts.nested, bool): - self.opts.nested = parent.opts.nested - 1 - else: - self.opts.nested = parent.opts.nested + super(BaseSerializer, self).initialize(parent, field_name) + if parent.opts.depth: + self.opts.depth = parent.opts.depth - 1 ##### - # Methods to convert or revert from objects <--> primative representations. + # Methods to convert or revert from objects <--> primitive representations. def get_field_key(self, field_name): """ @@ -174,35 +191,32 @@ class BaseSerializer(Field): Core of serialization. Convert an object into a dictionary of serialized field values. """ - if obj in self.stack and not self.source == '*': - raise RecursionOccured() - self.stack.append(obj) - ret = self._dict_class() ret.fields = {} - fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested) - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): + field.initialize(parent=self, field_name=field_name) key = self.get_field_key(field_name) - try: - value = field.field_to_native(obj, field_name) - except RecursionOccured: - field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name] - value = field.field_to_native(obj, field_name) + value = field.field_to_native(obj, field_name) ret[key] = value ret.fields[key] = field return ret - def restore_fields(self, data): + def restore_fields(self, data, files): """ Core of deserialization, together with `restore_object`. Converts a dictionary of data into a dictionary of deserialized fields. """ - fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested) reverted_data = {} - for field_name, field in fields.items(): + + if data is not None and not isinstance(data, dict): + self._errors['non_field_errors'] = [u'Invalid data'] + return None + + for field_name, field in self.fields.items(): + field.initialize(parent=self, field_name=field_name) try: - field.field_from_native(data, field_name, reverted_data) + field.field_from_native(data, files, field_name, reverted_data) except ValidationError as err: self._errors[field_name] = list(err.messages) @@ -212,9 +226,7 @@ class BaseSerializer(Field): """ Run `validate_<fieldname>()` and `validate()` methods on the serializer """ - fields = self.get_fields(serialize=False, data=attrs, nested=self.opts.nested) - - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): try: validate_method = getattr(self, 'validate_%s' % field_name, None) if validate_method: @@ -223,10 +235,18 @@ class BaseSerializer(Field): except ValidationError as err: self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - try: - attrs = self.validate(attrs) - except ValidationError as err: - self._errors['non_field_errors'] = err.messages + # If there are already errors, we don't run .validate() because + # field-validation failed and thus `attrs` may not be complete. + # which in turn can cause inconsistent validation errors. + if not self._errors: + try: + attrs = self.validate(attrs) + except ValidationError as err: + if hasattr(err, 'message_dict'): + for field_name, error_messages in err.message_dict.items(): + self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) + elif hasattr(err, 'messages'): + self._errors['non_field_errors'] = err.messages return attrs @@ -249,26 +269,23 @@ class BaseSerializer(Field): def to_native(self, obj): """ - Serialize objects -> primatives. + Serialize objects -> primitives. """ - if isinstance(obj, dict): - return dict([(key, self.to_native(val)) - for (key, val) in obj.items()]) - elif hasattr(obj, '__iter__'): - return [self.to_native(item) for item in obj] + if hasattr(obj, '__iter__'): + return [self.convert_object(item) for item in obj] return self.convert_object(obj) - def from_native(self, data): + def from_native(self, data, files): """ - Deserialize primatives -> objects. + Deserialize primitives -> objects. """ if hasattr(data, '__iter__') and not isinstance(data, dict): # TODO: error data when deserializing lists - return (self.from_native(item) for item in data) + return [self.from_native(item, None) for item in data] self._errors = {} - if data is not None: - attrs = self.restore_fields(data) + if data is not None or files is not None: + attrs = self.restore_fields(data, files) attrs = self.perform_validation(attrs) else: self._errors['non_field_errors'] = ['No input provided'] @@ -281,22 +298,36 @@ class BaseSerializer(Field): Override default so that we can apply ModelSerializer as a nested field to relationships. """ - obj = getattr(obj, self.source or field_name) + try: + if self.source: + for component in self.source.split('.'): + obj = getattr(obj, component) + if is_simple_callable(obj): + obj = obj() + else: + obj = getattr(obj, field_name) + if is_simple_callable(obj): + obj = obj() + except ObjectDoesNotExist: + return None # If the object has an "all" method, assume it's a relationship if is_simple_callable(getattr(obj, 'all', None)): return [self.to_native(item) for item in obj.all()] + if obj is None: + return None + return self.to_native(obj) @property def errors(self): """ Run deserialization and return error data, - setting self.object if no errors occured. + setting self.object if no errors occurred. """ if self._errors is None: - obj = self.from_native(self.init_data) + obj = self.from_native(self.init_data, self.init_files) if not self._errors: self.object = obj return self._errors @@ -329,6 +360,7 @@ class ModelSerializerOptions(SerializerOptions): def __init__(self, meta): super(ModelSerializerOptions, self).__init__(meta) self.model = getattr(meta, 'model', None) + self.read_only_fields = getattr(meta, 'read_only_fields', ()) class ModelSerializer(Serializer): @@ -337,16 +369,10 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions - def default_fields(self, serialize, obj=None, data=None, nested=False): + def get_default_fields(self): """ Return all the fields that should be serialized for the model. """ - # TODO: Modfiy this so that it's called on init, and drop - # serialize/obj/data arguments. - # - # We *could* provide a hook for dynamic fields, but - # it'd be nice if the default was to generate fields statically - # at the point of __init__ cls = self.opts.model opts = get_concrete_model(cls)._meta @@ -358,6 +384,7 @@ class ModelSerializer(Serializer): fields += [field for field in opts.many_to_many if field.serialize] ret = SortedDict() + nested = bool(self.opts.depth) is_pk = True # First field in the list is the pk for model_field in fields: @@ -374,22 +401,30 @@ class ModelSerializer(Serializer): field = self.get_field(model_field) if field: - field.initialize(parent=self) ret[model_field.name] = field + for field_name in self.opts.read_only_fields: + assert field_name in ret, \ + "read_only_fields on '%s' included invalid item '%s'" % \ + (self.__class__.__name__, field_name) + ret[field_name].read_only = True + return ret def get_pk_field(self, model_field): """ Returns a default instance of the pk field. """ - return Field() + return self.get_field(model_field) def get_nested_field(self, model_field): """ Creates a default instance of a nested relational field. """ - return ModelSerializer() + class NestedModelSerializer(ModelSerializer): + class Meta: + model = model_field.rel.to + return NestedModelSerializer() def get_related_field(self, model_field, to_many=False): """ @@ -397,20 +432,43 @@ class ModelSerializer(Serializer): """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - queryset = model_field.rel.to._default_manager + kwargs = { + 'null': model_field.null or model_field.blank, + 'queryset': model_field.rel.to._default_manager + } + if to_many: - return ManyPrimaryKeyRelatedField(queryset=queryset) - return PrimaryKeyRelatedField(queryset=queryset) + return ManyPrimaryKeyRelatedField(**kwargs) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): """ Creates a default instance of a basic non-relational field. """ kwargs = {} + + kwargs['blank'] = model_field.blank + + if model_field.null or model_field.blank: + kwargs['required'] = False + + if isinstance(model_field, models.AutoField) or not model_field.editable: + kwargs['read_only'] = True + if model_field.has_default(): kwargs['required'] = False + kwargs['default'] = model_field.get_default() + + if model_field.__class__ == models.TextField: + kwargs['widget'] = widgets.Textarea + + # TODO: TypedChoiceField? + if model_field.flatchoices: # This ModelField contains choices + kwargs['choices'] = model_field.flatchoices + return ChoiceField(**kwargs) field_mapping = { + models.AutoField: IntegerField, models.FloatField: FloatField, models.IntegerField: IntegerField, models.PositiveIntegerField: IntegerField, @@ -420,42 +478,86 @@ class ModelSerializer(Serializer): models.DateField: DateField, models.EmailField: EmailField, models.CharField: CharField, + models.URLField: URLField, + models.SlugField: SlugField, models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, + models.FileField: FileField, + models.ImageField: ImageField, } try: return field_mapping[model_field.__class__](**kwargs) except KeyError: return ModelField(model_field=model_field, **kwargs) + def get_validation_exclusions(self): + """ + Return a list of field names to exclude from model validation. + """ + cls = self.opts.model + opts = get_concrete_model(cls)._meta + exclusions = [field.name for field in opts.fields + opts.many_to_many] + for field_name, field in self.fields.items(): + if field_name in exclusions and not field.read_only: + exclusions.remove(field_name) + return exclusions + def restore_object(self, attrs, instance=None): """ Restore the model instance. """ self.m2m_data = {} + self.related_data = {} - if instance: - for key, val in attrs.items(): - setattr(instance, key, val) - return instance + # Reverse fk relations + for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.related_data[field_name] = attrs.pop(field_name) + # Reverse m2m relations + for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.m2m_data[field_name] = attrs.pop(field_name) + + # Forward m2m relations for field in self.opts.model._meta.many_to_many: if field.name in attrs: self.m2m_data[field.name] = attrs.pop(field.name) - return self.opts.model(**attrs) - def save(self, save_m2m=True): + if instance is not None: + for key, val in attrs.items(): + setattr(instance, key, val) + + else: + instance = self.opts.model(**attrs) + + try: + instance.full_clean(exclude=self.get_validation_exclusions()) + except ValidationError, err: + self._errors = err.message_dict + return None + + return instance + + def save(self): """ Save the deserialized object and return it. """ self.object.save() - if self.m2m_data and save_m2m: + if getattr(self, 'm2m_data', None): for accessor_name, object_list in self.m2m_data.items(): setattr(self.object, accessor_name, object_list) self.m2m_data = {} + if getattr(self, 'related_data', None): + for accessor_name, object_list in self.related_data.items(): + setattr(self.object, accessor_name, object_list) + self.related_data = {} + return self.object @@ -502,9 +604,9 @@ class HyperlinkedModelSerializer(ModelSerializer): # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) rel = model_field.rel.to - queryset = rel._default_manager kwargs = { - 'queryset': queryset, + 'null': model_field.null, + 'queryset': rel._default_manager, 'view_name': self._get_default_view_name(rel) } if to_many: |