diff options
| author | Tom Christie | 2013-01-07 21:04:52 +0000 |
|---|---|---|
| committer | Tom Christie | 2013-01-07 21:04:52 +0000 |
| commit | 36fa722ebb1b438b710b90fe470fbdbf82fd676e (patch) | |
| tree | 9a837478ff46ebeed0b03fe9a430d72695cc2784 /rest_framework/serializers.py | |
| parent | 873a142af2f63084fd10bf35c13e79131837da07 (diff) | |
| parent | e429f702e00ed807d68e90cd6a6af2749eb0b73e (diff) | |
| download | django-rest-framework-36fa722ebb1b438b710b90fe470fbdbf82fd676e.tar.bz2 | |
Merged to latest master
Diffstat (limited to 'rest_framework/serializers.py')
| -rw-r--r-- | rest_framework/serializers.py | 203 |
1 files changed, 142 insertions, 61 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 95145d58..fa92838b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -14,7 +14,7 @@ from rest_framework.compat import get_concrete_model # This helps keep the seperation between model fields, form fields, and # serializer fields more explicit. - +from rest_framework.relations import * from rest_framework.fields import * @@ -22,7 +22,16 @@ class DictWithMetadata(dict): """ A dict-like object, that can have additional properties attached. """ - pass + def __getstate__(self): + """ + Used by pickle (e.g., caching). + Overriden to remove metadata from the dict, since it shouldn't be pickled + and may in some instances be unpickleable. + """ + # return an instance of the first dict in MRO that isn't a DictWithMetadata + for base in self.__class__.__mro__: + if not isinstance(base, DictWithMetadata) and isinstance(base, dict): + return base(self) class SortedDictWithMetadata(SortedDict, DictWithMetadata): @@ -60,7 +69,7 @@ def _get_declared_fields(bases, attrs): # If this class is subclassing another Serializer, add that Serializer's # fields. Note that we loop over the bases in *reverse*. This is necessary - # in order to the correct order of fields. + # in order to maintain the correct order of fields. for base in bases[::-1]: if hasattr(base, 'base_fields'): fields = base.base_fields.items() + fields @@ -89,50 +98,53 @@ class BaseSerializer(Field): pass _options_class = SerializerOptions - _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations. + _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations. - def __init__(self, instance=None, data=None, context=None, **kwargs): + def __init__(self, instance=None, data=None, files=None, + context=None, partial=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) - self.fields = copy.deepcopy(self.base_fields) self.parent = None self.root = None + self.partial = partial self.context = context or {} self.init_data = data + self.init_files = files self.object = instance + self.fields = self.get_fields() self._data = None + self._files = None self._errors = None ##### # Methods to determine which fields to use when (de)serializing objects. - def default_fields(self, nested=False): + def get_default_fields(self): """ Return the complete set of default fields for the object, as a dict. """ return {} - def get_fields(self, nested=False): + def get_fields(self): """ Returns the complete set of fields for the object as a dict. This will be the set of any explicitly declared fields, - plus the set of fields returned by default_fields(). + plus the set of fields returned by get_default_fields(). """ ret = SortedDict() # Get the explicitly declared fields - for key, field in self.fields.items(): + base_fields = copy.deepcopy(self.base_fields) + for key, field in base_fields.items(): ret[key] = field - # Set up the field - field.initialize(parent=self, field_name=key) # Add in the default fields - fields = self.default_fields(nested) - for key, val in fields.items(): + default_fields = self.get_default_fields() + for key, val in default_fields.items(): if key not in ret: ret[key] = val @@ -148,6 +160,9 @@ class BaseSerializer(Field): for key in self.opts.exclude: ret.pop(key, None) + for key, field in ret.items(): + field.initialize(parent=self, field_name=key) + return ret ##### @@ -163,7 +178,7 @@ class BaseSerializer(Field): self.opts.depth = parent.opts.depth - 1 ##### - # Methods to convert or revert from objects <--> primative representations. + # Methods to convert or revert from objects <--> primitive representations. def get_field_key(self, field_name): """ @@ -179,24 +194,29 @@ class BaseSerializer(Field): ret = self._dict_class() ret.fields = {} - fields = self.get_fields(nested=bool(self.opts.depth)) - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): + field.initialize(parent=self, field_name=field_name) key = self.get_field_key(field_name) value = field.field_to_native(obj, field_name) ret[key] = value ret.fields[key] = field return ret - def restore_fields(self, data): + def restore_fields(self, data, files): """ Core of deserialization, together with `restore_object`. Converts a dictionary of data into a dictionary of deserialized fields. """ - fields = self.get_fields(nested=bool(self.opts.depth)) reverted_data = {} - for field_name, field in fields.items(): + + if data is not None and not isinstance(data, dict): + self._errors['non_field_errors'] = [u'Invalid data'] + return None + + for field_name, field in self.fields.items(): + field.initialize(parent=self, field_name=field_name) try: - field.field_from_native(data, field_name, reverted_data) + field.field_from_native(data, files, field_name, reverted_data) except ValidationError as err: self._errors[field_name] = list(err.messages) @@ -206,10 +226,7 @@ class BaseSerializer(Field): """ Run `validate_<fieldname>()` and `validate()` methods on the serializer """ - # TODO: refactor this so we're not determining the fields again - fields = self.get_fields(nested=bool(self.opts.depth)) - - for field_name, field in fields.items(): + for field_name, field in self.fields.items(): try: validate_method = getattr(self, 'validate_%s' % field_name, None) if validate_method: @@ -218,10 +235,18 @@ class BaseSerializer(Field): except ValidationError as err: self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) - try: - attrs = self.validate(attrs) - except ValidationError as err: - self._errors['non_field_errors'] = err.messages + # If there are already errors, we don't run .validate() because + # field-validation failed and thus `attrs` may not be complete. + # which in turn can cause inconsistent validation errors. + if not self._errors: + try: + attrs = self.validate(attrs) + except ValidationError as err: + if hasattr(err, 'message_dict'): + for field_name, error_messages in err.message_dict.items(): + self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) + elif hasattr(err, 'messages'): + self._errors['non_field_errors'] = err.messages return attrs @@ -244,23 +269,23 @@ class BaseSerializer(Field): def to_native(self, obj): """ - Serialize objects -> primatives. + Serialize objects -> primitives. """ if hasattr(obj, '__iter__'): return [self.convert_object(item) for item in obj] return self.convert_object(obj) - def from_native(self, data): + def from_native(self, data, files): """ - Deserialize primatives -> objects. + Deserialize primitives -> objects. """ if hasattr(data, '__iter__') and not isinstance(data, dict): # TODO: error data when deserializing lists - return (self.from_native(item) for item in data) + return [self.from_native(item, None) for item in data] self._errors = {} - if data is not None: - attrs = self.restore_fields(data) + if data is not None or files is not None: + attrs = self.restore_fields(data, files) attrs = self.perform_validation(attrs) else: self._errors['non_field_errors'] = ['No input provided'] @@ -273,12 +298,23 @@ class BaseSerializer(Field): Override default so that we can apply ModelSerializer as a nested field to relationships. """ - obj = getattr(obj, self.source or field_name) + if self.source: + for component in self.source.split('.'): + obj = getattr(obj, component) + if is_simple_callable(obj): + obj = obj() + else: + obj = getattr(obj, field_name) + if is_simple_callable(obj): + obj = value() # If the object has an "all" method, assume it's a relationship if is_simple_callable(getattr(obj, 'all', None)): return [self.to_native(item) for item in obj.all()] + if obj is None: + return None + return self.to_native(obj) @property @@ -288,7 +324,7 @@ class BaseSerializer(Field): setting self.object if no errors occurred. """ if self._errors is None: - obj = self.from_native(self.init_data) + obj = self.from_native(self.init_data, self.init_files) if not self._errors: self.object = obj return self._errors @@ -321,6 +357,7 @@ class ModelSerializerOptions(SerializerOptions): def __init__(self, meta): super(ModelSerializerOptions, self).__init__(meta) self.model = getattr(meta, 'model', None) + self.read_only_fields = getattr(meta, 'read_only_fields', ()) class ModelSerializer(Serializer): @@ -329,16 +366,10 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions - def default_fields(self, nested=False): + def get_default_fields(self): """ Return all the fields that should be serialized for the model. """ - # TODO: Modfiy this so that it's called on init, and drop - # serialize/obj/data arguments. - # - # We *could* provide a hook for dynamic fields, but - # it'd be nice if the default was to generate fields statically - # at the point of __init__ cls = self.opts.model opts = get_concrete_model(cls)._meta @@ -350,6 +381,7 @@ class ModelSerializer(Serializer): fields += [field for field in opts.many_to_many if field.serialize] ret = SortedDict() + nested = bool(self.opts.depth) is_pk = True # First field in the list is the pk for model_field in fields: @@ -366,9 +398,14 @@ class ModelSerializer(Serializer): field = self.get_field(model_field) if field: - field.initialize(parent=self, field_name=model_field.name) ret[model_field.name] = field + for field_name in self.opts.read_only_fields: + assert field_name in ret, \ + "read_only_fields on '%s' included invalid item '%s'" % \ + (self.__class__.__name__, field_name) + ret[field_name].read_only = True + return ret def get_pk_field(self, model_field): @@ -381,7 +418,10 @@ class ModelSerializer(Serializer): """ Creates a default instance of a nested relational field. """ - return ModelSerializer() + class NestedModelSerializer(ModelSerializer): + class Meta: + model = model_field.rel.to + return NestedModelSerializer() def get_related_field(self, model_field, to_many=False): """ @@ -389,10 +429,14 @@ class ModelSerializer(Serializer): """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) - queryset = model_field.rel.to._default_manager + kwargs = { + 'null': model_field.null, + 'queryset': model_field.rel.to._default_manager + } + if to_many: - return ManyPrimaryKeyRelatedField(queryset=queryset) - return PrimaryKeyRelatedField(queryset=queryset) + return ManyPrimaryKeyRelatedField(**kwargs) + return PrimaryKeyRelatedField(**kwargs) def get_field(self, model_field): """ @@ -402,7 +446,7 @@ class ModelSerializer(Serializer): kwargs['blank'] = model_field.blank - if model_field.null: + if model_field.null or model_field.blank: kwargs['required'] = False if model_field.has_default(): @@ -427,49 +471,86 @@ class ModelSerializer(Serializer): models.DateField: DateField, models.EmailField: EmailField, models.CharField: CharField, + models.URLField: URLField, + models.SlugField: SlugField, models.TextField: CharField, models.CommaSeparatedIntegerField: CharField, models.BooleanField: BooleanField, + models.FileField: FileField, + models.ImageField: ImageField, } try: return field_mapping[model_field.__class__](**kwargs) except KeyError: return ModelField(model_field=model_field, **kwargs) + def get_validation_exclusions(self): + """ + Return a list of field names to exclude from model validation. + """ + cls = self.opts.model + opts = get_concrete_model(cls)._meta + exclusions = [field.name for field in opts.fields + opts.many_to_many] + for field_name, field in self.fields.items(): + if field_name in exclusions and not field.read_only: + exclusions.remove(field_name) + return exclusions + def restore_object(self, attrs, instance=None): """ Restore the model instance. """ self.m2m_data = {} + self.related_data = {} - if instance: - for key, val in attrs.items(): - setattr(instance, key, val) - return instance + # Reverse fk relations + for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + field_name = obj.field.related_query_name() + if field_name in attrs: + self.related_data[field_name] = attrs.pop(field_name) - # Reverse relations + # Reverse m2m relations for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): field_name = obj.field.related_query_name() if field_name in attrs: self.m2m_data[field_name] = attrs.pop(field_name) - # Forward relations + # Forward m2m relations for field in self.opts.model._meta.many_to_many: if field.name in attrs: self.m2m_data[field.name] = attrs.pop(field.name) - return self.opts.model(**attrs) - def save(self, save_m2m=True): + if instance is not None: + for key, val in attrs.items(): + setattr(instance, key, val) + + else: + instance = self.opts.model(**attrs) + + try: + instance.full_clean(exclude=self.get_validation_exclusions()) + except ValidationError, err: + self._errors = err.message_dict + return None + + return instance + + def save(self): """ Save the deserialized object and return it. """ self.object.save() - if getattr(self, 'm2m_data', None) and save_m2m: + if getattr(self, 'm2m_data', None): for accessor_name, object_list in self.m2m_data.items(): setattr(self.object, accessor_name, object_list) self.m2m_data = {} + if getattr(self, 'related_data', None): + for accessor_name, object_list in self.related_data.items(): + setattr(self.object, accessor_name, object_list) + self.related_data = {} + return self.object @@ -516,9 +597,9 @@ class HyperlinkedModelSerializer(ModelSerializer): # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) rel = model_field.rel.to - queryset = rel._default_manager kwargs = { - 'queryset': queryset, + 'null': model_field.null, + 'queryset': rel._default_manager, 'view_name': self._get_default_view_name(rel) } if to_many: |
