aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/serializers.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/serializers.py')
-rw-r--r--rest_framework/serializers.py258
1 files changed, 168 insertions, 90 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 27458f96..4fe857a6 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -1,11 +1,13 @@
+from __future__ import unicode_literals
import copy
import datetime
import types
from decimal import Decimal
+from django.core.paginator import Page
from django.db import models
from django.forms import widgets
from django.utils.datastructures import SortedDict
-from rest_framework.compat import get_concrete_model
+from rest_framework.compat import get_concrete_model, six
# Note: We do the following so that users of the framework can use this style:
#
@@ -25,20 +27,23 @@ class DictWithMetadata(dict):
def __getstate__(self):
"""
Used by pickle (e.g., caching).
- Overriden to remove metadata from the dict, since it shouldn't be pickled
- and may in some instances be unpickleable.
+ Overriden to remove the metadata from the dict, since it shouldn't be
+ pickled and may in some instances be unpickleable.
"""
- # return an instance of the first dict in MRO that isn't a DictWithMetadata
- for base in self.__class__.__mro__:
- if not isinstance(base, DictWithMetadata) and isinstance(base, dict):
- return base(self)
+ return dict(self)
-class SortedDictWithMetadata(SortedDict, DictWithMetadata):
+class SortedDictWithMetadata(SortedDict):
"""
A sorted dict-like object, that can have additional properties attached.
"""
- pass
+ def __getstate__(self):
+ """
+ Used by pickle (e.g., caching).
+ Overriden to remove the metadata from the dict, since it shouldn't be
+ pickle and may in some instances be unpickleable.
+ """
+ return SortedDict(self).__dict__
def _is_protected_type(obj):
@@ -63,7 +68,7 @@ def _get_declared_fields(bases, attrs):
Note that all fields from the base classes are used.
"""
fields = [(field_name, attrs.pop(field_name))
- for field_name, obj in attrs.items()
+ for field_name, obj in list(six.iteritems(attrs))
if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1].creation_counter)
@@ -72,7 +77,7 @@ def _get_declared_fields(bases, attrs):
# in order to maintain the correct order of fields.
for base in bases[::-1]:
if hasattr(base, 'base_fields'):
- fields = base.base_fields.items() + fields
+ fields = list(base.base_fields.items()) + fields
return SortedDict(fields)
@@ -94,19 +99,24 @@ class SerializerOptions(object):
class BaseSerializer(Field):
+ """
+ This is the Serializer implementation.
+ We need to implement it as `BaseSerializer` due to metaclass magicks.
+ """
class Meta(object):
pass
_options_class = SerializerOptions
- _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatibility with unsorted implementations.
+ _dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, **kwargs):
- super(BaseSerializer, self).__init__(**kwargs)
+ context=None, partial=False, many=None, source=None):
+ super(BaseSerializer, self).__init__(source=source)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
+ self.many = many
self.context = context or {}
@@ -150,6 +160,7 @@ class BaseSerializer(Field):
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
+ assert isinstance(self.opts.fields, (list, tuple)), '`include` must be a list or tuple'
new = SortedDict()
for key in self.opts.fields:
new[key] = ret[key]
@@ -157,6 +168,7 @@ class BaseSerializer(Field):
# Remove anything in 'exclude'
if self.opts.exclude:
+ assert isinstance(self.opts.fields, (list, tuple)), '`exclude` must be a list or tuple'
for key in self.opts.exclude:
ret.pop(key, None)
@@ -186,22 +198,6 @@ class BaseSerializer(Field):
"""
return field_name
- def convert_object(self, obj):
- """
- Core of serialization.
- Convert an object into a dictionary of serialized field values.
- """
- ret = self._dict_class()
- ret.fields = {}
-
- for field_name, field in self.fields.items():
- field.initialize(parent=self, field_name=field_name)
- key = self.get_field_key(field_name)
- value = field.field_to_native(obj, field_name)
- ret[key] = value
- ret.fields[key] = field
- return ret
-
def restore_fields(self, data, files):
"""
Core of deserialization, together with `restore_object`.
@@ -210,7 +206,7 @@ class BaseSerializer(Field):
reverted_data = {}
if data is not None and not isinstance(data, dict):
- self._errors['non_field_errors'] = [u'Invalid data']
+ self._errors['non_field_errors'] = ['Invalid data']
return None
for field_name, field in self.fields.items():
@@ -227,6 +223,8 @@ class BaseSerializer(Field):
Run `validate_<fieldname>()` and `validate()` methods on the serializer
"""
for field_name, field in self.fields.items():
+ if field_name in self._errors:
+ continue
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
@@ -271,18 +269,21 @@ class BaseSerializer(Field):
"""
Serialize objects -> primitives.
"""
- if hasattr(obj, '__iter__'):
- return [self.convert_object(item) for item in obj]
- return self.convert_object(obj)
+ ret = self._dict_class()
+ ret.fields = {}
+
+ for field_name, field in self.fields.items():
+ field.initialize(parent=self, field_name=field_name)
+ key = self.get_field_key(field_name)
+ value = field.field_to_native(obj, field_name)
+ ret[key] = value
+ ret.fields[key] = field
+ return ret
def from_native(self, data, files):
"""
Deserialize primitives -> objects.
"""
- if hasattr(data, '__iter__') and not isinstance(data, dict):
- # TODO: error data when deserializing lists
- return [self.from_native(item, None) for item in data]
-
self._errors = {}
if data is not None or files is not None:
attrs = self.restore_fields(data, files)
@@ -298,6 +299,9 @@ class BaseSerializer(Field):
Override default so that we can apply ModelSerializer as a nested
field to relationships.
"""
+ if self.source == '*':
+ return self.to_native(obj)
+
try:
if self.source:
for component in self.source.split('.'):
@@ -318,6 +322,13 @@ class BaseSerializer(Field):
if obj is None:
return None
+ if self.many is not None:
+ many = self.many
+ else:
+ many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict, six.text_type))
+
+ if many:
+ return [self.to_native(item) for item in obj]
return self.to_native(obj)
@property
@@ -327,9 +338,30 @@ class BaseSerializer(Field):
setting self.object if no errors occurred.
"""
if self._errors is None:
- obj = self.from_native(self.init_data, self.init_files)
+ data, files = self.init_data, self.init_files
+
+ if self.many is not None:
+ many = self.many
+ else:
+ many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
+ if many:
+ warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ 'Use the `many=True` flag when instantiating the serializer.',
+ PendingDeprecationWarning, stacklevel=3)
+
+ if many:
+ ret = []
+ errors = []
+ for item in data:
+ ret.append(self.from_native(item, None))
+ errors.append(self._errors)
+ self._errors = any(errors) and errors or []
+ else:
+ ret = self.from_native(data, files)
+
if not self._errors:
- self.object = obj
+ self.object = ret
+
return self._errors
def is_valid(self):
@@ -337,20 +369,44 @@ class BaseSerializer(Field):
@property
def data(self):
+ """
+ Returns the serialized data on the serializer.
+ """
if self._data is None:
- self._data = self.to_native(self.object)
+ obj = self.object
+
+ if self.many is not None:
+ many = self.many
+ else:
+ many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
+ if many:
+ warnings.warn('Implict list/queryset serialization is due to be deprecated. '
+ 'Use the `many=True` flag when instantiating the serializer.',
+ PendingDeprecationWarning, stacklevel=2)
+
+ if many:
+ self._data = [self.to_native(item) for item in obj]
+ else:
+ self._data = self.to_native(obj)
+
return self._data
- def save(self):
+ def save_object(self, obj, **kwargs):
+ obj.save(**kwargs)
+
+ def save(self, **kwargs):
"""
Save the deserialized object and return it.
"""
- self.object.save()
+ if isinstance(self.object, list):
+ [self.save_object(item, **kwargs) for item in self.object]
+ else:
+ self.save_object(self.object, **kwargs)
return self.object
-class Serializer(BaseSerializer):
- __metaclass__ = SerializerMetaclass
+class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):
+ pass
class ModelSerializerOptions(SerializerOptions):
@@ -369,16 +425,42 @@ class ModelSerializer(Serializer):
"""
_options_class = ModelSerializerOptions
+ field_mapping = {
+ models.AutoField: IntegerField,
+ models.FloatField: FloatField,
+ models.IntegerField: IntegerField,
+ models.PositiveIntegerField: IntegerField,
+ models.SmallIntegerField: IntegerField,
+ models.PositiveSmallIntegerField: IntegerField,
+ models.DateTimeField: DateTimeField,
+ models.DateField: DateField,
+ models.TimeField: TimeField,
+ models.EmailField: EmailField,
+ models.CharField: CharField,
+ models.URLField: URLField,
+ models.SlugField: SlugField,
+ models.TextField: CharField,
+ models.CommaSeparatedIntegerField: CharField,
+ models.BooleanField: BooleanField,
+ models.FileField: FileField,
+ models.ImageField: ImageField,
+ }
+
def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
cls = self.opts.model
+ assert cls is not None, \
+ "Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta
pk_field = opts.pk
- while pk_field.rel:
+
+ # If model is a child via multitable inheritance, use parent's pk
+ while pk_field.rel and pk_field.rel.parent_link:
pk_field = pk_field.rel.to._meta.pk
+
fields = [pk_field]
fields += [field for field in opts.fields if field.serialize]
fields += [field for field in opts.many_to_many if field.serialize]
@@ -433,12 +515,11 @@ class ModelSerializer(Serializer):
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = {
- 'null': model_field.null or model_field.blank,
- 'queryset': model_field.rel.to._default_manager
+ 'required': not(model_field.null or model_field.blank),
+ 'queryset': model_field.rel.to._default_manager,
+ 'many': to_many
}
- if to_many:
- return ManyPrimaryKeyRelatedField(**kwargs)
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
@@ -446,20 +527,18 @@ class ModelSerializer(Serializer):
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
+ has_default = model_field.has_default()
- kwargs['blank'] = model_field.blank
-
- if model_field.null or model_field.blank:
+ if model_field.null or model_field.blank or has_default:
kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
- if model_field.has_default():
- kwargs['required'] = False
+ if has_default:
kwargs['default'] = model_field.get_default()
- if model_field.__class__ == models.TextField:
+ if issubclass(model_field.__class__, models.TextField):
kwargs['widget'] = widgets.Textarea
# TODO: TypedChoiceField?
@@ -467,27 +546,8 @@ class ModelSerializer(Serializer):
kwargs['choices'] = model_field.flatchoices
return ChoiceField(**kwargs)
- field_mapping = {
- models.AutoField: IntegerField,
- models.FloatField: FloatField,
- models.IntegerField: IntegerField,
- models.PositiveIntegerField: IntegerField,
- models.SmallIntegerField: IntegerField,
- models.PositiveSmallIntegerField: IntegerField,
- models.DateTimeField: DateTimeField,
- models.DateField: DateField,
- models.EmailField: EmailField,
- models.CharField: CharField,
- models.URLField: URLField,
- models.SlugField: SlugField,
- models.TextField: CharField,
- models.CommaSeparatedIntegerField: CharField,
- models.BooleanField: BooleanField,
- models.FileField: FileField,
- models.ImageField: ImageField,
- }
try:
- return field_mapping[model_field.__class__](**kwargs)
+ return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
return ModelField(model_field=model_field, **kwargs)
@@ -499,10 +559,27 @@ class ModelSerializer(Serializer):
opts = get_concrete_model(cls)._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
for field_name, field in self.fields.items():
+ field_name = field.source or field_name
if field_name in exclusions and not field.read_only:
exclusions.remove(field_name)
return exclusions
+ def full_clean(self, instance):
+ """
+ Perform Django's full_clean, and populate the `errors` dictionary
+ if any validation errors occur.
+
+ Note that we don't perform this inside the `.restore_object()` method,
+ so that subclasses can override `.restore_object()`, and still get
+ the full_clean validation checking.
+ """
+ try:
+ instance.full_clean(exclude=self.get_validation_exclusions())
+ except ValidationError as err:
+ self._errors = err.message_dict
+ return None
+ return instance
+
def restore_object(self, attrs, instance=None):
"""
Restore the model instance.
@@ -534,19 +611,21 @@ class ModelSerializer(Serializer):
else:
instance = self.opts.model(**attrs)
- try:
- instance.full_clean(exclude=self.get_validation_exclusions())
- except ValidationError, err:
- self._errors = err.message_dict
- return None
-
return instance
- def save(self):
+ def from_native(self, data, files):
+ """
+ Override the default method to also include model field validation.
+ """
+ instance = super(ModelSerializer, self).from_native(data, files)
+ if instance:
+ return self.full_clean(instance)
+
+ def save_object(self, obj, **kwargs):
"""
Save the deserialized object and return it.
"""
- self.object.save()
+ obj.save(**kwargs)
if getattr(self, 'm2m_data', None):
for accessor_name, object_list in self.m2m_data.items():
@@ -558,8 +637,6 @@ class ModelSerializer(Serializer):
setattr(self.object, accessor_name, object_list)
self.related_data = {}
- return self.object
-
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
"""
@@ -572,6 +649,8 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer):
"""
+ A subclass of ModelSerializer that uses hyperlinked relationships,
+ instead of primary key relationships.
"""
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
@@ -605,10 +684,9 @@ class HyperlinkedModelSerializer(ModelSerializer):
# .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to
kwargs = {
- 'null': model_field.null,
+ 'required': not(model_field.null or model_field.blank),
'queryset': rel._default_manager,
- 'view_name': self._get_default_view_name(rel)
+ 'view_name': self._get_default_view_name(rel),
+ 'many': to_many
}
- if to_many:
- return ManyHyperlinkedRelatedField(**kwargs)
return HyperlinkedRelatedField(**kwargs)