aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--djangorestframework/fields.py446
-rw-r--r--djangorestframework/parsers.py16
-rw-r--r--djangorestframework/request.py3
-rw-r--r--djangorestframework/serializers.py348
-rw-r--r--djangorestframework/tests/parsers.py6
-rw-r--r--djangorestframework/tests/renderers.py2
6 files changed, 809 insertions, 12 deletions
diff --git a/djangorestframework/fields.py b/djangorestframework/fields.py
new file mode 100644
index 00000000..a44eb417
--- /dev/null
+++ b/djangorestframework/fields.py
@@ -0,0 +1,446 @@
+import copy
+import datetime
+import inspect
+import warnings
+
+from django.core import validators
+from django.core.exceptions import ValidationError
+from django.conf import settings
+from django.db import DEFAULT_DB_ALIAS
+from django.db.models.related import RelatedObject
+from django.utils import timezone
+from django.utils.dateparse import parse_date, parse_datetime
+from django.utils.encoding import is_protected_type, smart_unicode
+from django.utils.translation import ugettext_lazy as _
+
+
+def is_simple_callable(obj):
+ """
+ True if the object is a callable that takes no arguments.
+ """
+ return (
+ (inspect.isfunction(obj) and not inspect.getargspec(obj)[0]) or
+ (inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1)
+ )
+
+
+class Field(object):
+ creation_counter = 0
+ default_validators = []
+ default_error_messages = {
+ 'required': _('This field is required.'),
+ 'invalid': _('Invalid value.'),
+ }
+ empty = ''
+
+ def __init__(self, source=None, readonly=False, required=None,
+ validators=[], error_messages=None):
+ self.parent = None
+
+ self.creation_counter = Field.creation_counter
+ Field.creation_counter += 1
+
+ self.source = source
+ self.readonly = readonly
+ self.required = not(readonly)
+
+ messages = {}
+ for c in reversed(self.__class__.__mro__):
+ messages.update(getattr(c, 'default_error_messages', {}))
+ messages.update(error_messages or {})
+ self.error_messages = messages
+
+ self.validators = self.default_validators + validators
+
+ def initialize(self, parent, model_field=None):
+ """
+ Called to set up a field prior to field_to_native or field_from_native.
+
+ parent - The parent serializer.
+ model_field - The model field this field corrosponds to, if one exists.
+ """
+ self.parent = parent
+ self.root = parent.root or parent
+ self.context = self.root.context
+ if model_field:
+ self.model_field = model_field
+
+ def validate(self, value):
+ pass
+ # if value in validators.EMPTY_VALUES and self.required:
+ # raise ValidationError(self.error_messages['required'])
+
+ def run_validators(self, value):
+ if value in validators.EMPTY_VALUES:
+ return
+ errors = []
+ for v in self.validators:
+ try:
+ v(value)
+ except ValidationError as e:
+ if hasattr(e, 'code') and e.code in self.error_messages:
+ message = self.error_messages[e.code]
+ if e.params:
+ message = message % e.params
+ errors.append(message)
+ else:
+ errors.extend(e.messages)
+ if errors:
+ raise ValidationError(errors)
+
+ def field_from_native(self, data, field_name, into):
+ """
+ Given a dictionary and a field name, updates the dictionary `into`,
+ with the field and it's deserialized value.
+ """
+ if self.readonly:
+ return
+
+ try:
+ native = data[field_name]
+ except KeyError:
+ return # TODO Consider validation behaviour, 'required' opt etc...
+
+ value = self.from_native(native)
+ if self.source == '*':
+ if value:
+ into.update(value)
+ else:
+ self.validate(value)
+ self.run_validators(value)
+ into[self.source or field_name] = value
+
+ def from_native(self, value):
+ """
+ Reverts a simple representation back to the field's value.
+ """
+ if hasattr(self, 'model_field'):
+ try:
+ return self.model_field.rel.to._meta.get_field(self.model_field.rel.field_name).to_python(value)
+ except:
+ return self.model_field.to_python(value)
+ return value
+
+ def field_to_native(self, obj, field_name):
+ """
+ Given and object and a field name, returns the value that should be
+ serialized for that field.
+ """
+ if obj is None:
+ return self.empty
+
+ if self.source == '*':
+ return self.to_native(obj)
+
+ self.obj = obj # Need to hang onto this in the case of model fields
+ if hasattr(self, 'model_field'):
+ return self.to_native(self.model_field._get_val_from_obj(obj))
+
+ return self.to_native(getattr(obj, self.source or field_name))
+
+ def to_native(self, value):
+ """
+ Converts the field's value into it's simple representation.
+ """
+ if is_simple_callable(value):
+ value = value()
+
+ if is_protected_type(value):
+ return value
+ elif hasattr(self, 'model_field'):
+ return self.model_field.value_to_string(self.obj)
+ return smart_unicode(value)
+
+ def attributes(self):
+ """
+ Returns a dictionary of attributes to be used when serializing to xml.
+ """
+ try:
+ return {
+ "type": self.model_field.get_internal_type()
+ }
+ except AttributeError:
+ return {}
+
+
+class RelatedField(Field):
+ """
+ A base class for model related fields or related managers.
+
+ Subclass this and override `convert` to define custom behaviour when
+ serializing related objects.
+ """
+
+ def field_to_native(self, obj, field_name):
+ obj = getattr(obj, field_name)
+ if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
+ return [self.to_native(item) for item in obj.all()]
+ return self.to_native(obj)
+
+ def attributes(self):
+ try:
+ return {
+ "rel": self.model_field.rel.__class__.__name__,
+ "to": smart_unicode(self.model_field.rel.to._meta)
+ }
+ except AttributeError:
+ return {}
+
+
+class PrimaryKeyRelatedField(RelatedField):
+ """
+ Serializes a model related field or related manager to a pk value.
+ """
+
+ # Note the we use ModelRelatedField's implementation, as we want to get the
+ # raw database value directly, since that won't involve another
+ # database lookup.
+ #
+ # An alternative implementation would simply be this...
+ #
+ # class PrimaryKeyRelatedField(RelatedField):
+ # def to_native(self, obj):
+ # return obj.pk
+
+ def to_native(self, pk):
+ """
+ Simply returns the object's pk. You can subclass this method to
+ provide different serialization behavior of the pk.
+ (For example returning a URL based on the model's pk.)
+ """
+ return pk
+
+ def field_to_native(self, obj, field_name):
+ try:
+ obj = obj.serializable_value(field_name)
+ except AttributeError:
+ field = obj._meta.get_field_by_name(field_name)[0]
+ obj = getattr(obj, field_name)
+ if obj.__class__.__name__ == 'RelatedManager':
+ return [self.to_native(item.pk) for item in obj.all()]
+ elif isinstance(field, RelatedObject):
+ return self.to_native(obj.pk)
+ raise
+ if obj.__class__.__name__ == 'ManyRelatedManager':
+ return [self.to_native(item.pk) for item in obj.all()]
+ return self.to_native(obj)
+
+ def field_from_native(self, data, field_name, into):
+ value = data.get(field_name)
+ if hasattr(value, '__iter__'):
+ into[field_name] = [self.from_native(item) for item in value]
+ else:
+ into[field_name + '_id'] = self.from_native(value)
+
+
+class NaturalKeyRelatedField(RelatedField):
+ """
+ Serializes a model related field or related manager to a natural key value.
+ """
+ is_natural_key = True # XML renderer handles these differently
+
+ def to_native(self, obj):
+ if hasattr(obj, 'natural_key'):
+ return obj.natural_key()
+ return obj
+
+ def field_from_native(self, data, field_name, into):
+ value = data.get(field_name)
+ into[self.model_field.attname] = self.from_native(value)
+
+ def from_native(self, value):
+ # TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS)
+ manager = self.model_field.rel.to._default_manager
+ manager = manager.db_manager(DEFAULT_DB_ALIAS)
+ return manager.get_by_natural_key(*value).pk
+
+
+class BooleanField(Field):
+ default_error_messages = {
+ 'invalid': _(u"'%s' value must be either True or False."),
+ }
+
+ def from_native(self, value):
+ if value in (True, False):
+ # if value is 1 or 0 than it's equal to True or False, but we want
+ # to return a true bool for semantic reasons.
+ return bool(value)
+ if value in ('t', 'True', '1'):
+ return True
+ if value in ('f', 'False', '0'):
+ return False
+ raise ValidationError(self.error_messages['invalid'] % value)
+
+
+class CharField(Field):
+ def __init__(self, max_length=None, min_length=None, *args, **kwargs):
+ self.max_length, self.min_length = max_length, min_length
+ super(CharField, self).__init__(*args, **kwargs)
+ if min_length is not None:
+ self.validators.append(validators.MinLengthValidator(min_length))
+ if max_length is not None:
+ self.validators.append(validators.MaxLengthValidator(max_length))
+
+ def from_native(self, value):
+ if isinstance(value, basestring) or value is None:
+ return value
+ return smart_unicode(value)
+
+
+class EmailField(CharField):
+ default_error_messages = {
+ 'invalid': _('Enter a valid e-mail address.'),
+ }
+ default_validators = [validators.validate_email]
+
+ def from_native(self, value):
+ return super(EmailField, self).from_native(value).strip()
+
+ def __deepcopy__(self, memo):
+ result = copy.copy(self)
+ memo[id(self)] = result
+ #result.widget = copy.deepcopy(self.widget, memo)
+ result.validators = self.validators[:]
+ return result
+
+
+class DateField(Field):
+ default_error_messages = {
+ 'invalid': _(u"'%s' value has an invalid date format. It must be "
+ u"in YYYY-MM-DD format."),
+ 'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) "
+ u"but it is an invalid date."),
+ }
+ empty = None
+
+ def from_native(self, value):
+ if value is None:
+ return value
+ if isinstance(value, datetime.datetime):
+ if settings.USE_TZ and timezone.is_aware(value):
+ # Convert aware datetimes to the default time zone
+ # before casting them to dates (#17742).
+ default_timezone = timezone.get_default_timezone()
+ value = timezone.make_naive(value, default_timezone)
+ return value.date()
+ if isinstance(value, datetime.date):
+ return value
+
+ try:
+ parsed = parse_date(value)
+ if parsed is not None:
+ return parsed
+ except ValueError:
+ msg = self.error_messages['invalid_date'] % value
+ raise ValidationError(msg)
+
+ msg = self.error_messages['invalid'] % value
+ raise ValidationError(msg)
+
+
+class DateTimeField(Field):
+ default_error_messages = {
+ 'invalid': _(u"'%s' value has an invalid format. It must be in "
+ u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
+ 'invalid_date': _(u"'%s' value has the correct format "
+ u"(YYYY-MM-DD) but it is an invalid date."),
+ 'invalid_datetime': _(u"'%s' value has the correct format "
+ u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
+ u"but it is an invalid date/time."),
+ }
+ empty = None
+
+ def from_native(self, value):
+ if value is None:
+ return value
+ if isinstance(value, datetime.datetime):
+ return value
+ if isinstance(value, datetime.date):
+ value = datetime.datetime(value.year, value.month, value.day)
+ if settings.USE_TZ:
+ # For backwards compatibility, interpret naive datetimes in
+ # local time. This won't work during DST change, but we can't
+ # do much about it, so we let the exceptions percolate up the
+ # call stack.
+ warnings.warn(u"DateTimeField received a naive datetime (%s)"
+ u" while time zone support is active." % value,
+ RuntimeWarning)
+ default_timezone = timezone.get_default_timezone()
+ value = timezone.make_aware(value, default_timezone)
+ return value
+
+ try:
+ parsed = parse_datetime(value)
+ if parsed is not None:
+ return parsed
+ except ValueError:
+ msg = self.error_messages['invalid_datetime'] % value
+ raise ValidationError(msg)
+
+ try:
+ parsed = parse_date(value)
+ if parsed is not None:
+ return datetime.datetime(parsed.year, parsed.month, parsed.day)
+ except ValueError:
+ msg = self.error_messages['invalid_date'] % value
+ raise ValidationError(msg)
+
+ msg = self.error_messages['invalid'] % value
+ raise ValidationError(msg)
+
+
+class IntegerField(Field):
+ default_error_messages = {
+ 'invalid': _('Enter a whole number.'),
+ 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
+ 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
+ }
+
+ def __init__(self, max_value=None, min_value=None, *args, **kwargs):
+ self.max_value, self.min_value = max_value, min_value
+ super(IntegerField, self).__init__(*args, **kwargs)
+
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def from_native(self, value):
+ if value in validators.EMPTY_VALUES:
+ return None
+ try:
+ value = int(str(value))
+ except (ValueError, TypeError):
+ raise ValidationError(self.error_messages['invalid'])
+ return value
+
+
+class FloatField(Field):
+ default_error_messages = {
+ 'invalid': _("'%s' value must be a float."),
+ }
+
+ def from_native(self, value):
+ if value is None:
+ return value
+ try:
+ return float(value)
+ except (TypeError, ValueError):
+ msg = self.error_messages['invalid'] % value
+ raise ValidationError(msg)
+
+# field_mapping = {
+# models.AutoField: IntegerField,
+# models.BooleanField: BooleanField,
+# models.CharField: CharField,
+# models.DateTimeField: DateTimeField,
+# models.DateField: DateField,
+# models.BigIntegerField: IntegerField,
+# models.IntegerField: IntegerField,
+# models.PositiveIntegerField: IntegerField,
+# models.FloatField: FloatField
+# }
+
+
+# def modelfield_to_serializerfield(field):
+# return field_mapping.get(type(field), Field)
diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py
index 1fff64f7..43ea0c4d 100644
--- a/djangorestframework/parsers.py
+++ b/djangorestframework/parsers.py
@@ -57,7 +57,7 @@ class BaseParser(object):
"""
return media_type_matches(self.media_type, content_type)
- def parse(self, stream, meta, upload_handlers):
+ def parse(self, stream, **opts):
"""
Given a *stream* to read from, return the deserialized output.
Should return a 2-tuple of (data, files).
@@ -72,7 +72,7 @@ class JSONParser(BaseParser):
media_type = 'application/json'
- def parse(self, stream, meta, upload_handlers):
+ def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
@@ -92,7 +92,7 @@ class YAMLParser(BaseParser):
media_type = 'application/yaml'
- def parse(self, stream, meta, upload_handlers):
+ def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
@@ -112,7 +112,7 @@ class PlainTextParser(BaseParser):
media_type = 'text/plain'
- def parse(self, stream, meta, upload_handlers):
+ def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
@@ -129,7 +129,7 @@ class FormParser(BaseParser):
media_type = 'application/x-www-form-urlencoded'
- def parse(self, stream, meta, upload_handlers):
+ def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
@@ -147,13 +147,15 @@ class MultiPartParser(BaseParser):
media_type = 'multipart/form-data'
- def parse(self, stream, meta, upload_handlers):
+ def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
`data` will be a :class:`QueryDict` containing all the form parameters.
`files` will be a :class:`QueryDict` containing all the form files.
"""
+ meta = opts['meta']
+ upload_handlers = opts['upload_handlers']
try:
parser = DjangoMultiPartParser(meta, stream, upload_handlers)
return parser.parse()
@@ -168,7 +170,7 @@ class XMLParser(BaseParser):
media_type = 'application/xml'
- def parse(self, stream, meta, upload_handlers):
+ def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
diff --git a/djangorestframework/request.py b/djangorestframework/request.py
index 684f6591..84ca0575 100644
--- a/djangorestframework/request.py
+++ b/djangorestframework/request.py
@@ -214,7 +214,8 @@ class Request(object):
for parser in self.get_parsers():
if parser.can_handle_request(self.content_type):
- return parser.parse(self.stream, self.META, self.upload_handlers)
+ return parser.parse(self.stream, meta=self.META,
+ upload_handlers=self.upload_handlers)
raise UnsupportedMediaType(self._content_type)
diff --git a/djangorestframework/serializers.py b/djangorestframework/serializers.py
new file mode 100644
index 00000000..46980ee6
--- /dev/null
+++ b/djangorestframework/serializers.py
@@ -0,0 +1,348 @@
+from decimal import Decimal
+from django.core.serializers.base import DeserializedObject
+from django.utils.datastructures import SortedDict
+import copy
+import datetime
+import types
+from djangorestframework.fields import *
+
+
+class DictWithMetadata(dict):
+ """
+ A dict-like object, that can have additional properties attached.
+ """
+ pass
+
+
+class SortedDictWithMetadata(SortedDict, DictWithMetadata):
+ """
+ A sorted dict-like object, that can have additional properties attached.
+ """
+ pass
+
+
+class RecursionOccured(BaseException):
+ pass
+
+
+def _is_protected_type(obj):
+ """
+ True if the object is a native datatype that does not need to
+ be serialized further.
+ """
+ return isinstance(obj, (
+ types.NoneType,
+ int, long,
+ datetime.datetime, datetime.date, datetime.time,
+ float, Decimal,
+ basestring)
+ )
+
+
+def _get_declared_fields(bases, attrs):
+ """
+ Create a list of serializer field instances from the passed in 'attrs',
+ plus any fields on the base classes (in 'bases').
+
+ Note that all fields from the base classes are used.
+ """
+ fields = [(field_name, attrs.pop(field_name))
+ for field_name, obj in attrs.items()
+ if isinstance(obj, Field)]
+ fields.sort(key=lambda x: x[1].creation_counter)
+
+ # If this class is subclassing another Serializer, add that Serializer's
+ # fields. Note that we loop over the bases in *reverse*. This is necessary
+ # in order to the correct order of fields.
+ for base in bases[::-1]:
+ if hasattr(base, 'base_fields'):
+ fields = base.base_fields.items() + fields
+
+ return SortedDict(fields)
+
+
+class SerializerMetaclass(type):
+ def __new__(cls, name, bases, attrs):
+ attrs['base_fields'] = _get_declared_fields(bases, attrs)
+ return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
+
+
+class SerializerOptions(object):
+ """
+ Meta class options for ModelSerializer
+ """
+ def __init__(self, meta):
+ self.nested = getattr(meta, 'nested', False)
+ self.fields = getattr(meta, 'fields', ())
+ self.exclude = getattr(meta, 'exclude', ())
+
+
+class BaseSerializer(Field):
+ class Meta(object):
+ pass
+
+ _options_class = SerializerOptions
+ _dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
+
+ def __init__(self, data=None, instance=None, context=None, **kwargs):
+ super(BaseSerializer, self).__init__(**kwargs)
+ self.fields = copy.deepcopy(self.base_fields)
+ self.opts = self._options_class(self.Meta)
+ self.parent = None
+ self.root = None
+
+ self.stack = []
+ self.context = context or {}
+
+ self.init_data = data
+ self.instance = instance
+
+ self._data = None
+ self._errors = None
+
+ #####
+ # Methods to determine which fields to use when (de)serializing objects.
+
+ def default_fields(self, serialize, obj=None, data=None, nested=False):
+ """
+ Return the complete set of default fields for the object, as a dict.
+ """
+ return {}
+
+ def get_fields(self, serialize, obj=None, data=None, nested=False):
+ """
+ Returns the complete set of fields for the object as a dict.
+
+ This will be the set of any explicitly declared fields,
+ plus the set of fields returned by get_default_fields().
+ """
+ ret = SortedDict()
+
+ # Get the explicitly declared fields
+ for key, field in self.fields.items():
+ ret[key] = field
+ # Determine if the declared field corrosponds to a model field.
+ try:
+ if key == 'pk':
+ model_field = obj._meta.pk
+ else:
+ model_field = obj._meta.get_field_by_name(key)[0]
+ except:
+ model_field = None
+ # Set up the field
+ field.initialize(parent=self, model_field=model_field)
+
+ # Add in the default fields
+ fields = self.default_fields(serialize, obj, data, nested)
+ for key, val in fields.items():
+ if key not in ret:
+ ret[key] = val
+
+ # If 'fields' is specified, use those fields, in that order.
+ if self.opts.fields:
+ new = SortedDict()
+ for key in self.opts.fields:
+ new[key] = ret[key]
+ ret = new
+
+ # Remove anything in 'exclude'
+ if self.opts.exclude:
+ for key in self.opts.exclude:
+ ret.pop(key, None)
+
+ return ret
+
+ #####
+ # Field methods - used when the serializer class is itself used as a field.
+
+ def initialize(self, parent, model_field=None):
+ """
+ Same behaviour as usual Field, except that we need to keep track
+ of state so that we can deal with handling maximum depth and recursion.
+ """
+ super(BaseSerializer, self).initialize(parent, model_field)
+ self.stack = parent.stack[:]
+ if parent.opts.nested and not isinstance(parent.opts.nested, bool):
+ self.opts.nested = parent.opts.nested - 1
+ else:
+ self.opts.nested = parent.opts.nested
+
+ #####
+ # Methods to convert or revert from objects <--> primative representations.
+
+ def get_field_key(self, field_name):
+ """
+ Return the key that should be used for a given field.
+ """
+ return field_name
+
+ def convert_object(self, obj):
+ """
+ Core of serialization.
+ Convert an object into a dictionary of serialized field values.
+ """
+ if obj in self.stack and not self.source == '*':
+ raise RecursionOccured()
+ self.stack.append(obj)
+
+ ret = self._dict_class()
+ ret.fields = {}
+
+ fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested)
+ for field_name, field in fields.items():
+ key = self.get_field_key(field_name)
+ try:
+ value = field.field_to_native(obj, field_name)
+ except RecursionOccured:
+ field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name]
+ value = field.field_to_native(obj, field_name)
+ ret[key] = value
+ ret.fields[key] = field
+ return ret
+
+ def restore_fields(self, data):
+ """
+ Core of deserialization, together with `restore_object`.
+ Converts a dictionary of data into a dictionary of deserialized fields.
+ """
+ fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested)
+ reverted_data = {}
+ for field_name, field in fields.items():
+ try:
+ field.field_from_native(data, field_name, reverted_data)
+ except ValidationError as err:
+ self._errors[field_name] = list(err.messages)
+
+ return reverted_data
+
+ def restore_object(self, attrs, instance=None):
+ """
+ Deserialize a dictionary of attributes into an object instance.
+ You should override this method to control how deserialized objects
+ are instantiated.
+ """
+ if instance is not None:
+ instance.update(attrs)
+ return instance
+ return attrs
+
+ def to_native(self, obj):
+ """
+ Serialize objects -> primatives.
+ """
+ if isinstance(obj, dict):
+ return dict([(key, self.to_native(val))
+ for (key, val) in obj.items()])
+ elif hasattr(obj, '__iter__'):
+ return (self.to_native(item) for item in obj)
+ return self.convert_object(obj)
+
+ def from_native(self, data):
+ """
+ Deserialize primatives -> objects.
+ """
+ if hasattr(data, '__iter__') and not isinstance(data, dict):
+ # TODO: error data when deserializing lists
+ return (self.from_native(item) for item in data)
+ self._errors = {}
+ attrs = self.restore_fields(data)
+ if not self._errors:
+ return self.restore_object(attrs, instance=getattr(self, 'instance', None))
+
+ @property
+ def errors(self):
+ """
+ Run deserialization and return error data,
+ setting self.object if no errors occured.
+ """
+ if self._errors is None:
+ obj = self.from_native(self.init_data)
+ if not self._errors:
+ self.object = obj
+ return self._errors
+
+ def is_valid(self):
+ return not self.errors
+
+ @property
+ def data(self):
+ if self._data is None:
+ self._data = self.to_native(self.instance)
+ return self._data
+
+
+class Serializer(BaseSerializer):
+ __metaclass__ = SerializerMetaclass
+
+
+class ModelSerializerOptions(SerializerOptions):
+ """
+ Meta class options for ModelSerializer
+ """
+ def __init__(self, meta):
+ super(ModelSerializerOptions, self).__init__(meta)
+ self.model = getattr(meta, 'model', None)
+
+
+class ModelSerializer(RelatedField, Serializer):
+ """
+ A serializer that deals with model instances and querysets.
+ """
+ _options_class = ModelSerializerOptions
+
+ def default_fields(self, serialize, obj=None, data=None, nested=False):
+ """
+ Return all the fields that should be serialized for the model.
+ """
+ if serialize:
+ cls = obj.__class__
+ else:
+ cls = self.opts.model
+
+ opts = cls._meta.concrete_model._meta
+ pk_field = opts.pk
+ while pk_field.rel:
+ pk_field = pk_field.rel.to._meta.pk
+ fields = [pk_field]
+ fields += [field for field in opts.fields if field.serialize]
+ fields += [field for field in opts.many_to_many if field.serialize]
+
+ ret = SortedDict()
+ for model_field in fields:
+ if model_field.rel and nested:
+ field = self.get_nested_field(model_field)
+ elif model_field.rel:
+ field = self.get_related_field(model_field)
+ else:
+ field = self.get_field(model_field)
+ field.initialize(parent=self, model_field=model_field)
+ ret[model_field.name] = field
+ return ret
+
+ def get_nested_field(self, model_field):
+ """
+ Creates a default instance of a nested relational field.
+ """
+ return ModelSerializer()
+
+ def get_related_field(self, model_field):
+ """
+ Creates a default instance of a flat relational field.
+ """
+ return PrimaryKeyRelatedField()
+
+ def get_field(self, model_field):
+ """
+ Creates a default instance of a basic field.
+ """
+ return Field()
+
+ def restore_object(self, attrs, instance=None):
+ """
+ Restore the model instance.
+ """
+ m2m_data = {}
+ for field in self.opts.model._meta.many_to_many:
+ if field.name in attrs:
+ m2m_data[field.name] = attrs.pop(field.name)
+ return DeserializedObject(self.opts.model(**attrs), m2m_data)
diff --git a/djangorestframework/tests/parsers.py b/djangorestframework/tests/parsers.py
index c733d9d0..a85409dc 100644
--- a/djangorestframework/tests/parsers.py
+++ b/djangorestframework/tests/parsers.py
@@ -153,7 +153,7 @@ class TestFormParser(TestCase):
parser = FormParser()
stream = StringIO(self.string)
- (data, files) = parser.parse(stream, {}, [])
+ (data, files) = parser.parse(stream)
self.assertEqual(Form(data).is_valid(), True)
@@ -203,10 +203,10 @@ class TestXMLParser(TestCase):
def test_parse(self):
parser = XMLParser()
- (data, files) = parser.parse(self._input, {}, [])
+ (data, files) = parser.parse(self._input)
self.assertEqual(data, self._data)
def test_complex_data_parse(self):
parser = XMLParser()
- (data, files) = parser.parse(self._complex_data_input, {}, [])
+ (data, files) = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)
diff --git a/djangorestframework/tests/renderers.py b/djangorestframework/tests/renderers.py
index 610457c7..1943d012 100644
--- a/djangorestframework/tests/renderers.py
+++ b/djangorestframework/tests/renderers.py
@@ -380,7 +380,7 @@ class XMLRendererTestCase(TestCase):
content = StringIO(renderer.render(self._complex_data, 'application/xml'))
parser = XMLParser()
- complex_data_out, dummy = parser.parse(content, {}, [])
+ complex_data_out, dummy = parser.parse(content)
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
self.assertEqual(self._complex_data, complex_data_out, error_msg)