aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/serializers.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/serializers.py')
-rw-r--r--rest_framework/serializers.py320
1 files changed, 237 insertions, 83 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index d83367f4..af8aeb48 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -11,6 +11,7 @@ python primitives.
response content is handled by parsers and renderers.
"""
from django.core.exceptions import ImproperlyConfigured
+from django.core.exceptions import ValidationError as DjangoValidationError
from django.db import models
from django.db.models.fields import FieldDoesNotExist
from django.utils import six
@@ -46,6 +47,9 @@ import warnings
from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA
+
+# We assume that 'validators' are intended for the child serializer,
+# rather than the parent serializer.
LIST_SERIALIZER_KWARGS = (
'read_only', 'write_only', 'required', 'default', 'initial', 'source',
'label', 'help_text', 'style', 'error_messages',
@@ -73,13 +77,36 @@ class BaseSerializer(Field):
# We override this method in order to automagically create
# `ListSerializer` classes instead when `many=True` is set.
if kwargs.pop('many', False):
- list_kwargs = {'child': cls(*args, **kwargs)}
- for key in kwargs.keys():
- if key in LIST_SERIALIZER_KWARGS:
- list_kwargs[key] = kwargs[key]
- return ListSerializer(*args, **list_kwargs)
+ return cls.many_init(*args, **kwargs)
return super(BaseSerializer, cls).__new__(cls, *args, **kwargs)
+ @classmethod
+ def many_init(cls, *args, **kwargs):
+ """
+ This method implements the creation of a `ListSerializer` parent
+ class when `many=True` is used. You can customize it if you need to
+ control which keyword arguments are passed to the parent, and
+ which are passed to the child.
+
+ Note that we're over-cautious in passing most arguments to both parent
+ and child classes in order to try to cover the general case. If you're
+ overriding this method you'll probably want something much simpler, eg:
+
+ @classmethod
+ def many_init(cls, *args, **kwargs):
+ kwargs['child'] = cls()
+ return CustomListSerializer(*args, **kwargs)
+ """
+ child_serializer = cls(*args, **kwargs)
+ list_kwargs = {'child': child_serializer}
+ list_kwargs.update(dict([
+ (key, value) for key, value in kwargs.items()
+ if key in LIST_SERIALIZER_KWARGS
+ ]))
+ meta = getattr(cls, 'Meta', None)
+ list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer)
+ return list_serializer_class(*args, **list_kwargs)
+
def to_internal_value(self, data):
raise NotImplementedError('`to_internal_value()` must be implemented.')
@@ -93,6 +120,21 @@ class BaseSerializer(Field):
raise NotImplementedError('`create()` must be implemented.')
def save(self, **kwargs):
+ assert not hasattr(self, 'save_object'), (
+ 'Serializer `%s.%s` has old-style version 2 `.save_object()` '
+ 'that is no longer compatible with REST framework 3. '
+ 'Use the new-style `.create()` and `.update()` methods instead.' %
+ (self.__class__.__module__, self.__class__.__name__)
+ )
+
+ assert hasattr(self, '_errors'), (
+ 'You must call `.is_valid()` before calling `.save()`.'
+ )
+
+ assert not self.errors, (
+ 'You cannot call `.save()` on a serializer with invalid data.'
+ )
+
validated_data = dict(
list(self.validated_data.items()) +
list(kwargs.items())
@@ -230,18 +272,18 @@ class Serializer(BaseSerializer):
def get_initial(self):
if self._initial_data is not None:
- return ReturnDict([
+ return OrderedDict([
(field_name, field.get_value(self._initial_data))
for field_name, field in self.fields.items()
if field.get_value(self._initial_data) is not empty
and not field.read_only
- ], serializer=self)
+ ])
- return ReturnDict([
+ return OrderedDict([
(field.field_name, field.get_initial())
for field in self.fields.values()
if not field.read_only
- ], serializer=self)
+ ])
def get_value(self, dictionary):
# We override the default field access in order to support
@@ -297,6 +339,14 @@ class Serializer(BaseSerializer):
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [exc.detail]
})
+ except DjangoValidationError as exc:
+ # Normally you should raise `serializers.ValidationError`
+ # inside your codebase, but we handle Django's validation
+ # exception class as well for simpler compat.
+ # Eg. Calling Model.clean() explictily inside Serializer.validate()
+ raise ValidationError({
+ api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages)
+ })
return value
@@ -304,8 +354,8 @@ class Serializer(BaseSerializer):
"""
Dict of native values <- Dict of primitive datatypes.
"""
- ret = {}
- errors = ReturnDict(serializer=self)
+ ret = OrderedDict()
+ errors = OrderedDict()
fields = [
field for field in self.fields.values()
if (not field.read_only) or (field.default is not empty)
@@ -320,6 +370,8 @@ class Serializer(BaseSerializer):
validated_value = validate_method(validated_value)
except ValidationError as exc:
errors[field.field_name] = exc.detail
+ except DjangoValidationError as exc:
+ errors[field.field_name] = list(exc.messages)
except SkipField:
pass
else:
@@ -334,20 +386,15 @@ class Serializer(BaseSerializer):
"""
Object instance -> Dict of primitive datatypes.
"""
- ret = ReturnDict(serializer=self)
+ ret = OrderedDict()
fields = [field for field in self.fields.values() if not field.write_only]
for field in fields:
attribute = field.get_attribute(instance)
if attribute is None:
- value = None
+ ret[field.field_name] = None
else:
- value = field.to_representation(attribute)
- transform_method = getattr(self, 'transform_' + field.field_name, None)
- if transform_method is not None:
- value = transform_method(value)
-
- ret[field.field_name] = value
+ ret[field.field_name] = field.to_representation(attribute)
return ret
@@ -373,6 +420,19 @@ class Serializer(BaseSerializer):
return NestedBoundField(field, value, error)
return BoundField(field, value, error)
+ # Include a backlink to the serializer class on return objects.
+ # Allows renderers such as HTMLFormRenderer to get the full field info.
+
+ @property
+ def data(self):
+ ret = super(Serializer, self).data
+ return ReturnDict(ret, serializer=self)
+
+ @property
+ def errors(self):
+ ret = super(Serializer, self).errors
+ return ReturnDict(ret, serializer=self)
+
# There's some replication of `ListField` here,
# but that's probably better than obfuscating the call hierarchy.
@@ -395,7 +455,7 @@ class ListSerializer(BaseSerializer):
def get_initial(self):
if self._initial_data is not None:
return self.to_representation(self._initial_data)
- return ReturnList(serializer=self)
+ return []
def get_value(self, dictionary):
"""
@@ -423,7 +483,7 @@ class ListSerializer(BaseSerializer):
})
ret = []
- errors = ReturnList(serializer=self)
+ errors = []
for item in data:
try:
@@ -444,37 +504,64 @@ class ListSerializer(BaseSerializer):
List of object instances -> List of dicts of primitive datatypes.
"""
iterable = data.all() if (hasattr(data, 'all')) else data
- return ReturnList(
- [self.child.to_representation(item) for item in iterable],
- serializer=self
+ return [
+ self.child.to_representation(item) for item in iterable
+ ]
+
+ def update(self, instance, validated_data):
+ raise NotImplementedError(
+ "Serializers with many=True do not support multiple update by "
+ "default, only multiple create. For updates it is unclear how to "
+ "deal with insertions and deletions. If you need to support "
+ "multiple update, use a `ListSerializer` class and override "
+ "`.update()` so you can specify the behavior exactly."
)
+ def create(self, validated_data):
+ return [
+ self.child.create(attrs) for attrs in validated_data
+ ]
+
def save(self, **kwargs):
"""
Save and return a list of object instances.
"""
- assert self.instance is None, (
- "Serializers do not support multiple update by default, only "
- "multiple create. For updates it is unclear how to deal with "
- "insertions and deletions. If you need to support multiple update, "
- "use a `ListSerializer` class and override `.save()` so you can "
- "specify the behavior exactly."
- )
-
validated_data = [
dict(list(attrs.items()) + list(kwargs.items()))
for attrs in self.validated_data
]
- self.instance = [
- self.child.create(attrs) for attrs in validated_data
- ]
+ if self.instance is not None:
+ self.instance = self.update(self.instance, validated_data)
+ assert self.instance is not None, (
+ '`update()` did not return an object instance.'
+ )
+ else:
+ self.instance = self.create(validated_data)
+ assert self.instance is not None, (
+ '`create()` did not return an object instance.'
+ )
return self.instance
def __repr__(self):
return representation.list_repr(self, indent=1)
+ # Include a backlink to the serializer class on return objects.
+ # Allows renderers such as HTMLFormRenderer to get the full field info.
+
+ @property
+ def data(self):
+ ret = super(ListSerializer, self).data
+ return ReturnList(ret, serializer=self)
+
+ @property
+ def errors(self):
+ ret = super(ListSerializer, self).errors
+ if isinstance(ret, dict):
+ return ReturnDict(ret, serializer=self)
+ return ReturnList(ret, serializer=self)
+
# ModelSerializer & HyperlinkedModelSerializer
# --------------------------------------------
@@ -486,6 +573,14 @@ class ModelSerializer(Serializer):
* A set of default fields are automatically populated.
* A set of default validators are automatically populated.
* Default `.create()` and `.update()` implementations are provided.
+
+ The process of automatically determining a set of serializer fields
+ based on the model fields is reasonably complex, but you almost certainly
+ don't need to dig into the implemention.
+
+ If the `ModelSerializer` class *doesn't* generate the set of fields that
+ you need you should either declare the extra/differing fields explicitly on
+ the serializer class, or simply use a `Serializer` class.
"""
_field_mapping = ClassLookupDict({
models.AutoField: IntegerField,
@@ -513,13 +608,33 @@ class ModelSerializer(Serializer):
})
_related_class = PrimaryKeyRelatedField
- def create(self, validated_attrs):
+ def create(self, validated_data):
+ """
+ We have a bit of extra checking around this in order to provide
+ descriptive messages when something goes wrong, but this method is
+ essentially just:
+
+ return ExampleModel.objects.create(**validated_data)
+
+ If there are many to many fields present on the instance then they
+ cannot be set until the model is instantiated, in which case the
+ implementation is like so:
+
+ example_relationship = validated_data.pop('example_relationship')
+ instance = ExampleModel.objects.create(**validated_data)
+ instance.example_relationship = example_relationship
+ return instance
+
+ The default implementation also does not handle nested relationships.
+ If you want to support writable nested relationships you'll need
+ to write an explicit `.create()` method.
+ """
# Check that the user isn't trying to handle a writable nested field.
# If we don't do this explicitly they'd likely get a confusing
# error at the point of calling `Model.objects.create()`.
assert not any(
- isinstance(field, BaseSerializer) and not field.read_only
- for field in self.fields.values()
+ isinstance(field, BaseSerializer) and (key in validated_attrs)
+ for key, field in self.fields.items()
), (
'The `.create()` method does not suport nested writable fields '
'by default. Write an explicit `.create()` method for serializer '
@@ -529,16 +644,33 @@ class ModelSerializer(Serializer):
ModelClass = self.Meta.model
- # Remove many-to-many relationships from validated_attrs.
+ # Remove many-to-many relationships from validated_data.
# They are not valid arguments to the default `.create()` method,
# as they require that the instance has already been saved.
info = model_meta.get_field_info(ModelClass)
many_to_many = {}
for field_name, relation_info in info.relations.items():
- if relation_info.to_many and (field_name in validated_attrs):
- many_to_many[field_name] = validated_attrs.pop(field_name)
+ if relation_info.to_many and (field_name in validated_data):
+ many_to_many[field_name] = validated_data.pop(field_name)
- instance = ModelClass.objects.create(**validated_attrs)
+ try:
+ instance = ModelClass.objects.create(**validated_data)
+ except TypeError as exc:
+ msg = (
+ 'Got a `TypeError` when calling `%s.objects.create()`. '
+ 'This may be because you have a writable field on the '
+ 'serializer class that is not a valid argument to '
+ '`%s.objects.create()`. You may need to make the field '
+ 'read-only, or override the %s.create() method to handle '
+ 'this correctly.\nOriginal exception text was: %s.' %
+ (
+ ModelClass.__name__,
+ ModelClass.__name__,
+ self.__class__.__name__,
+ exc
+ )
+ )
+ raise TypeError(msg)
# Save many-to-many relationships after the instance is created.
if many_to_many:
@@ -547,10 +679,10 @@ class ModelSerializer(Serializer):
return instance
- def update(self, instance, validated_attrs):
+ def update(self, instance, validated_data):
assert not any(
- isinstance(field, BaseSerializer) and not field.read_only
- for field in self.fields.values()
+ isinstance(field, BaseSerializer) and (key in validated_attrs)
+ for key, field in self.fields.items()
), (
'The `.update()` method does not suport nested writable fields '
'by default. Write an explicit `.update()` method for serializer '
@@ -558,20 +690,25 @@ class ModelSerializer(Serializer):
(self.__class__.__module__, self.__class__.__name__)
)
- for attr, value in validated_attrs.items():
+ for attr, value in validated_data.items():
setattr(instance, attr, value)
instance.save()
return instance
def get_validators(self):
+ # If the validators have been declared explicitly then use that.
+ validators = getattr(getattr(self, 'Meta', None), 'validators', None)
+ if validators is not None:
+ return validators
+
+ # Determine the default set of validators.
+ validators = []
+ model_class = self.Meta.model
field_names = set([
field.source for field in self.fields.values()
if (field.source != '*') and ('.' not in field.source)
])
- validators = getattr(getattr(self, 'Meta', None), 'validators', [])
- model_class = self.Meta.model
-
# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
for parent_class in [model_class] + list(model_class._meta.parents.keys()):
@@ -658,49 +795,62 @@ class ModelSerializer(Serializer):
# Determine if we need any additional `HiddenField` or extra keyword
# arguments to deal with `unique_for` dates that are required to
# be in the input data in order to validate it.
- unique_fields = {}
+ hidden_fields = {}
+ unique_constraint_names = set()
+
for model_field_name, field_name in model_field_mapping.items():
try:
model_field = model._meta.get_field(model_field_name)
except FieldDoesNotExist:
continue
- # Deal with each of the `unique_for_*` cases.
- for date_field_name in (
+ # Include each of the `unique_for_*` field names.
+ unique_constraint_names |= set([
model_field.unique_for_date,
model_field.unique_for_month,
model_field.unique_for_year
- ):
- if date_field_name is None:
- continue
-
- # Get the model field that is refered too.
- date_field = model._meta.get_field(date_field_name)
-
- if date_field.auto_now_add:
- default = CreateOnlyDefault(timezone.now)
- elif date_field.auto_now:
- default = timezone.now
- elif date_field.has_default():
- default = model_field.default
- else:
- default = empty
-
- if date_field_name in model_field_mapping:
- # The corresponding date field is present in the serializer
- if date_field_name not in extra_kwargs:
- extra_kwargs[date_field_name] = {}
- if default is empty:
- if 'required' not in extra_kwargs[date_field_name]:
- extra_kwargs[date_field_name]['required'] = True
- else:
- if 'default' not in extra_kwargs[date_field_name]:
- extra_kwargs[date_field_name]['default'] = default
+ ])
+
+ unique_constraint_names -= set([None])
+
+ # Include each of the `unique_together` field names,
+ # so long as all the field names are included on the serializer.
+ for parent_class in [model] + list(model._meta.parents.keys()):
+ for unique_together_list in parent_class._meta.unique_together:
+ if set(fields).issuperset(set(unique_together_list)):
+ unique_constraint_names |= set(unique_together_list)
+
+ # Now we have all the field names that have uniqueness constraints
+ # applied, we can add the extra 'required=...' or 'default=...'
+ # arguments that are appropriate to these fields, or add a `HiddenField` for it.
+ for unique_constraint_name in unique_constraint_names:
+ # Get the model field that is refered too.
+ unique_constraint_field = model._meta.get_field(unique_constraint_name)
+
+ if getattr(unique_constraint_field, 'auto_now_add', None):
+ default = CreateOnlyDefault(timezone.now)
+ elif getattr(unique_constraint_field, 'auto_now', None):
+ default = timezone.now
+ elif unique_constraint_field.has_default():
+ default = unique_constraint_field.default
+ else:
+ default = empty
+
+ if unique_constraint_name in model_field_mapping:
+ # The corresponding field is present in the serializer
+ if unique_constraint_name not in extra_kwargs:
+ extra_kwargs[unique_constraint_name] = {}
+ if default is empty:
+ if 'required' not in extra_kwargs[unique_constraint_name]:
+ extra_kwargs[unique_constraint_name]['required'] = True
else:
- # The corresponding date field is not present in the,
- # serializer. We have a default to use for the date, so
- # add in a hidden field that populates it.
- unique_fields[date_field_name] = HiddenField(default=default)
+ if 'default' not in extra_kwargs[unique_constraint_name]:
+ extra_kwargs[unique_constraint_name]['default'] = default
+ elif default is not empty:
+ # The corresponding field is not present in the,
+ # serializer. We have a default to use for it, so
+ # add in a hidden field that populates it.
+ hidden_fields[unique_constraint_name] = HiddenField(default=default)
# Now determine the fields that should be included on the serializer.
for field_name in fields:
@@ -776,12 +926,16 @@ class ModelSerializer(Serializer):
'validators', 'queryset'
]:
kwargs.pop(attr, None)
+
+ if extras.get('default') and kwargs.get('required') is False:
+ kwargs.pop('required')
+
kwargs.update(extras)
# Create the serializer field.
ret[field_name] = field_cls(**kwargs)
- for field_name, field in unique_fields.items():
+ for field_name, field in hidden_fields.items():
ret[field_name] = field
return ret