aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/serializers.py
diff options
context:
space:
mode:
authorTom Christie2012-11-09 13:49:52 +0000
committerTom Christie2012-11-09 13:49:52 +0000
commit8953a60196cb55ec75902882314da5a42636349c (patch)
tree43bf6ea1f69955aeecd83fb9f866d92ea9a5f3df /rest_framework/serializers.py
parentb78872b7dbb55f1aa2d21f15fbb952f0c7156326 (diff)
parent9aaeeacdfebc244850e82469e4af45af252cca4d (diff)
downloaddjango-rest-framework-8953a60196cb55ec75902882314da5a42636349c.tar.bz2
Merge with master
Diffstat (limited to 'rest_framework/serializers.py')
-rw-r--r--rest_framework/serializers.py169
1 files changed, 107 insertions, 62 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 13f8cde2..95145d58 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -3,8 +3,18 @@ import datetime
import types
from decimal import Decimal
from django.db import models
+from django.forms import widgets
from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model
+
+# Note: We do the following so that users of the framework can use this style:
+#
+# example_field = serializers.CharField(...)
+#
+# This helps keep the seperation between model fields, form fields, and
+# serializer fields more explicit.
+
+
from rest_framework.fields import *
@@ -22,10 +32,6 @@ class SortedDictWithMetadata(SortedDict, DictWithMetadata):
pass
-class RecursionOccured(BaseException):
- pass
-
-
def _is_protected_type(obj):
"""
True if the object is a native datatype that does not need to
@@ -33,10 +39,10 @@ def _is_protected_type(obj):
"""
return isinstance(obj, (
types.NoneType,
- int, long,
- datetime.datetime, datetime.date, datetime.time,
- float, Decimal,
- basestring)
+ int, long,
+ datetime.datetime, datetime.date, datetime.time,
+ float, Decimal,
+ basestring)
)
@@ -73,7 +79,7 @@ class SerializerOptions(object):
Meta class options for Serializer
"""
def __init__(self, meta):
- self.nested = getattr(meta, 'nested', False)
+ self.depth = getattr(meta, 'depth', 0)
self.fields = getattr(meta, 'fields', ())
self.exclude = getattr(meta, 'exclude', ())
@@ -85,14 +91,13 @@ class BaseSerializer(Field):
_options_class = SerializerOptions
_dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
- def __init__(self, data=None, instance=None, context=None, **kwargs):
+ def __init__(self, instance=None, data=None, context=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
- self.fields = copy.deepcopy(self.base_fields)
self.opts = self._options_class(self.Meta)
+ self.fields = copy.deepcopy(self.base_fields)
self.parent = None
self.root = None
- self.stack = []
self.context = context or {}
self.init_data = data
@@ -104,13 +109,13 @@ class BaseSerializer(Field):
#####
# Methods to determine which fields to use when (de)serializing objects.
- def default_fields(self, serialize, obj=None, data=None, nested=False):
+ def default_fields(self, nested=False):
"""
Return the complete set of default fields for the object, as a dict.
"""
return {}
- def get_fields(self, serialize, obj=None, data=None, nested=False):
+ def get_fields(self, nested=False):
"""
Returns the complete set of fields for the object as a dict.
@@ -123,10 +128,10 @@ class BaseSerializer(Field):
for key, field in self.fields.items():
ret[key] = field
# Set up the field
- field.initialize(parent=self)
+ field.initialize(parent=self, field_name=key)
# Add in the default fields
- fields = self.default_fields(serialize, obj, data, nested)
+ fields = self.default_fields(nested)
for key, val in fields.items():
if key not in ret:
ret[key] = val
@@ -148,17 +153,14 @@ class BaseSerializer(Field):
#####
# Field methods - used when the serializer class is itself used as a field.
- def initialize(self, parent):
+ def initialize(self, parent, field_name):
"""
Same behaviour as usual Field, except that we need to keep track
- of state so that we can deal with handling maximum depth and recursion.
+ of state so that we can deal with handling maximum depth.
"""
- super(BaseSerializer, self).initialize(parent)
- self.stack = parent.stack[:]
- if parent.opts.nested and not isinstance(parent.opts.nested, bool):
- self.opts.nested = parent.opts.nested - 1
- else:
- self.opts.nested = parent.opts.nested
+ super(BaseSerializer, self).initialize(parent, field_name)
+ if parent.opts.depth:
+ self.opts.depth = parent.opts.depth - 1
#####
# Methods to convert or revert from objects <--> primative representations.
@@ -174,21 +176,13 @@ class BaseSerializer(Field):
Core of serialization.
Convert an object into a dictionary of serialized field values.
"""
- if obj in self.stack and not self.source == '*':
- raise RecursionOccured()
- self.stack.append(obj)
-
ret = self._dict_class()
ret.fields = {}
- fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested)
+ fields = self.get_fields(nested=bool(self.opts.depth))
for field_name, field in fields.items():
key = self.get_field_key(field_name)
- try:
- value = field.field_to_native(obj, field_name)
- except RecursionOccured:
- field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name]
- value = field.field_to_native(obj, field_name)
+ value = field.field_to_native(obj, field_name)
ret[key] = value
ret.fields[key] = field
return ret
@@ -198,7 +192,7 @@ class BaseSerializer(Field):
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
- fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested)
+ fields = self.get_fields(nested=bool(self.opts.depth))
reverted_data = {}
for field_name, field in fields.items():
try:
@@ -208,6 +202,35 @@ class BaseSerializer(Field):
return reverted_data
+ def perform_validation(self, attrs):
+ """
+ Run `validate_<fieldname>()` and `validate()` methods on the serializer
+ """
+ # TODO: refactor this so we're not determining the fields again
+ fields = self.get_fields(nested=bool(self.opts.depth))
+
+ for field_name, field in fields.items():
+ try:
+ validate_method = getattr(self, 'validate_%s' % field_name, None)
+ if validate_method:
+ source = field.source or field_name
+ attrs = validate_method(attrs, source)
+ except ValidationError as err:
+ self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
+
+ try:
+ attrs = self.validate(attrs)
+ except ValidationError as err:
+ self._errors['non_field_errors'] = err.messages
+
+ return attrs
+
+ def validate(self, attrs):
+ """
+ Stub method, to be overridden in Serializer subclasses
+ """
+ return attrs
+
def restore_object(self, attrs, instance=None):
"""
Deserialize a dictionary of attributes into an object instance.
@@ -223,11 +246,8 @@ class BaseSerializer(Field):
"""
Serialize objects -> primatives.
"""
- if isinstance(obj, dict):
- return dict([(key, self.to_native(val))
- for (key, val) in obj.items()])
- elif hasattr(obj, '__iter__'):
- return [self.to_native(item) for item in obj]
+ if hasattr(obj, '__iter__'):
+ return [self.convert_object(item) for item in obj]
return self.convert_object(obj)
def from_native(self, data):
@@ -241,17 +261,31 @@ class BaseSerializer(Field):
self._errors = {}
if data is not None:
attrs = self.restore_fields(data)
+ attrs = self.perform_validation(attrs)
else:
- self._errors['non_field_errors'] = 'No input provided'
+ self._errors['non_field_errors'] = ['No input provided']
if not self._errors:
return self.restore_object(attrs, instance=getattr(self, 'object', None))
+ def field_to_native(self, obj, field_name):
+ """
+ Override default so that we can apply ModelSerializer as a nested
+ field to relationships.
+ """
+ obj = getattr(obj, self.source or field_name)
+
+ # If the object has an "all" method, assume it's a relationship
+ if is_simple_callable(getattr(obj, 'all', None)):
+ return [self.to_native(item) for item in obj.all()]
+
+ return self.to_native(obj)
+
@property
def errors(self):
"""
Run deserialization and return error data,
- setting self.object if no errors occured.
+ setting self.object if no errors occurred.
"""
if self._errors is None:
obj = self.from_native(self.init_data)
@@ -295,17 +329,7 @@ class ModelSerializer(Serializer):
"""
_options_class = ModelSerializerOptions
- def field_to_native(self, obj, field_name):
- """
- Override default so that we can apply ModelSerializer as a nested
- field to relationships.
- """
- obj = getattr(obj, self.source or field_name)
- if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
- return [self.to_native(item) for item in obj.all()]
- return self.to_native(obj)
-
- def default_fields(self, serialize, obj=None, data=None, nested=False):
+ def default_fields(self, nested=False):
"""
Return all the fields that should be serialized for the model.
"""
@@ -342,7 +366,7 @@ class ModelSerializer(Serializer):
field = self.get_field(model_field)
if field:
- field.initialize(parent=self)
+ field.initialize(parent=self, field_name=model_field.name)
ret[model_field.name] = field
return ret
@@ -374,6 +398,25 @@ class ModelSerializer(Serializer):
"""
Creates a default instance of a basic non-relational field.
"""
+ kwargs = {}
+
+ kwargs['blank'] = model_field.blank
+
+ if model_field.null:
+ kwargs['required'] = False
+
+ if model_field.has_default():
+ kwargs['required'] = False
+ kwargs['default'] = model_field.get_default()
+
+ if model_field.__class__ == models.TextField:
+ kwargs['widget'] = widgets.Textarea
+
+ # TODO: TypedChoiceField?
+ if model_field.flatchoices: # This ModelField contains choices
+ kwargs['choices'] = model_field.flatchoices
+ return ChoiceField(**kwargs)
+
field_mapping = {
models.FloatField: FloatField,
models.IntegerField: IntegerField,
@@ -389,14 +432,9 @@ class ModelSerializer(Serializer):
models.BooleanField: BooleanField,
}
try:
- ret = field_mapping[model_field.__class__]()
+ return field_mapping[model_field.__class__](**kwargs)
except KeyError:
- ret = ModelField(model_field=model_field)
-
- if model_field.default:
- ret.required = False
-
- return ret
+ return ModelField(model_field=model_field, **kwargs)
def restore_object(self, attrs, instance=None):
"""
@@ -409,6 +447,13 @@ class ModelSerializer(Serializer):
setattr(instance, key, val)
return instance
+ # Reverse relations
+ for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
+ field_name = obj.field.related_query_name()
+ if field_name in attrs:
+ self.m2m_data[field_name] = attrs.pop(field_name)
+
+ # Forward relations
for field in self.opts.model._meta.many_to_many:
if field.name in attrs:
self.m2m_data[field.name] = attrs.pop(field.name)
@@ -420,7 +465,7 @@ class ModelSerializer(Serializer):
"""
self.object.save()
- if self.m2m_data and save_m2m:
+ if getattr(self, 'm2m_data', None) and save_m2m:
for accessor_name, object_list in self.m2m_data.items():
setattr(self.object, accessor_name, object_list)
self.m2m_data = {}