diff options
Diffstat (limited to 'rest_framework/serializers.py')
| -rw-r--r-- | rest_framework/serializers.py | 320 |
1 files changed, 237 insertions, 83 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d83367f4..af8aeb48 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -11,6 +11,7 @@ python primitives. response content is handled by parsers and renderers. """ from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ValidationError as DjangoValidationError from django.db import models from django.db.models.fields import FieldDoesNotExist from django.utils import six @@ -46,6 +47,9 @@ import warnings from rest_framework.relations import * # NOQA from rest_framework.fields import * # NOQA + +# We assume that 'validators' are intended for the child serializer, +# rather than the parent serializer. LIST_SERIALIZER_KWARGS = ( 'read_only', 'write_only', 'required', 'default', 'initial', 'source', 'label', 'help_text', 'style', 'error_messages', @@ -73,13 +77,36 @@ class BaseSerializer(Field): # We override this method in order to automagically create # `ListSerializer` classes instead when `many=True` is set. if kwargs.pop('many', False): - list_kwargs = {'child': cls(*args, **kwargs)} - for key in kwargs.keys(): - if key in LIST_SERIALIZER_KWARGS: - list_kwargs[key] = kwargs[key] - return ListSerializer(*args, **list_kwargs) + return cls.many_init(*args, **kwargs) return super(BaseSerializer, cls).__new__(cls, *args, **kwargs) + @classmethod + def many_init(cls, *args, **kwargs): + """ + This method implements the creation of a `ListSerializer` parent + class when `many=True` is used. You can customize it if you need to + control which keyword arguments are passed to the parent, and + which are passed to the child. + + Note that we're over-cautious in passing most arguments to both parent + and child classes in order to try to cover the general case. If you're + overriding this method you'll probably want something much simpler, eg: + + @classmethod + def many_init(cls, *args, **kwargs): + kwargs['child'] = cls() + return CustomListSerializer(*args, **kwargs) + """ + child_serializer = cls(*args, **kwargs) + list_kwargs = {'child': child_serializer} + list_kwargs.update(dict([ + (key, value) for key, value in kwargs.items() + if key in LIST_SERIALIZER_KWARGS + ])) + meta = getattr(cls, 'Meta', None) + list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer) + return list_serializer_class(*args, **list_kwargs) + def to_internal_value(self, data): raise NotImplementedError('`to_internal_value()` must be implemented.') @@ -93,6 +120,21 @@ class BaseSerializer(Field): raise NotImplementedError('`create()` must be implemented.') def save(self, **kwargs): + assert not hasattr(self, 'save_object'), ( + 'Serializer `%s.%s` has old-style version 2 `.save_object()` ' + 'that is no longer compatible with REST framework 3. ' + 'Use the new-style `.create()` and `.update()` methods instead.' % + (self.__class__.__module__, self.__class__.__name__) + ) + + assert hasattr(self, '_errors'), ( + 'You must call `.is_valid()` before calling `.save()`.' + ) + + assert not self.errors, ( + 'You cannot call `.save()` on a serializer with invalid data.' + ) + validated_data = dict( list(self.validated_data.items()) + list(kwargs.items()) @@ -230,18 +272,18 @@ class Serializer(BaseSerializer): def get_initial(self): if self._initial_data is not None: - return ReturnDict([ + return OrderedDict([ (field_name, field.get_value(self._initial_data)) for field_name, field in self.fields.items() if field.get_value(self._initial_data) is not empty and not field.read_only - ], serializer=self) + ]) - return ReturnDict([ + return OrderedDict([ (field.field_name, field.get_initial()) for field in self.fields.values() if not field.read_only - ], serializer=self) + ]) def get_value(self, dictionary): # We override the default field access in order to support @@ -297,6 +339,14 @@ class Serializer(BaseSerializer): raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [exc.detail] }) + except DjangoValidationError as exc: + # Normally you should raise `serializers.ValidationError` + # inside your codebase, but we handle Django's validation + # exception class as well for simpler compat. + # Eg. Calling Model.clean() explictily inside Serializer.validate() + raise ValidationError({ + api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + }) return value @@ -304,8 +354,8 @@ class Serializer(BaseSerializer): """ Dict of native values <- Dict of primitive datatypes. """ - ret = {} - errors = ReturnDict(serializer=self) + ret = OrderedDict() + errors = OrderedDict() fields = [ field for field in self.fields.values() if (not field.read_only) or (field.default is not empty) @@ -320,6 +370,8 @@ class Serializer(BaseSerializer): validated_value = validate_method(validated_value) except ValidationError as exc: errors[field.field_name] = exc.detail + except DjangoValidationError as exc: + errors[field.field_name] = list(exc.messages) except SkipField: pass else: @@ -334,20 +386,15 @@ class Serializer(BaseSerializer): """ Object instance -> Dict of primitive datatypes. """ - ret = ReturnDict(serializer=self) + ret = OrderedDict() fields = [field for field in self.fields.values() if not field.write_only] for field in fields: attribute = field.get_attribute(instance) if attribute is None: - value = None + ret[field.field_name] = None else: - value = field.to_representation(attribute) - transform_method = getattr(self, 'transform_' + field.field_name, None) - if transform_method is not None: - value = transform_method(value) - - ret[field.field_name] = value + ret[field.field_name] = field.to_representation(attribute) return ret @@ -373,6 +420,19 @@ class Serializer(BaseSerializer): return NestedBoundField(field, value, error) return BoundField(field, value, error) + # Include a backlink to the serializer class on return objects. + # Allows renderers such as HTMLFormRenderer to get the full field info. + + @property + def data(self): + ret = super(Serializer, self).data + return ReturnDict(ret, serializer=self) + + @property + def errors(self): + ret = super(Serializer, self).errors + return ReturnDict(ret, serializer=self) + # There's some replication of `ListField` here, # but that's probably better than obfuscating the call hierarchy. @@ -395,7 +455,7 @@ class ListSerializer(BaseSerializer): def get_initial(self): if self._initial_data is not None: return self.to_representation(self._initial_data) - return ReturnList(serializer=self) + return [] def get_value(self, dictionary): """ @@ -423,7 +483,7 @@ class ListSerializer(BaseSerializer): }) ret = [] - errors = ReturnList(serializer=self) + errors = [] for item in data: try: @@ -444,37 +504,64 @@ class ListSerializer(BaseSerializer): List of object instances -> List of dicts of primitive datatypes. """ iterable = data.all() if (hasattr(data, 'all')) else data - return ReturnList( - [self.child.to_representation(item) for item in iterable], - serializer=self + return [ + self.child.to_representation(item) for item in iterable + ] + + def update(self, instance, validated_data): + raise NotImplementedError( + "Serializers with many=True do not support multiple update by " + "default, only multiple create. For updates it is unclear how to " + "deal with insertions and deletions. If you need to support " + "multiple update, use a `ListSerializer` class and override " + "`.update()` so you can specify the behavior exactly." ) + def create(self, validated_data): + return [ + self.child.create(attrs) for attrs in validated_data + ] + def save(self, **kwargs): """ Save and return a list of object instances. """ - assert self.instance is None, ( - "Serializers do not support multiple update by default, only " - "multiple create. For updates it is unclear how to deal with " - "insertions and deletions. If you need to support multiple update, " - "use a `ListSerializer` class and override `.save()` so you can " - "specify the behavior exactly." - ) - validated_data = [ dict(list(attrs.items()) + list(kwargs.items())) for attrs in self.validated_data ] - self.instance = [ - self.child.create(attrs) for attrs in validated_data - ] + if self.instance is not None: + self.instance = self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` did not return an object instance.' + ) + else: + self.instance = self.create(validated_data) + assert self.instance is not None, ( + '`create()` did not return an object instance.' + ) return self.instance def __repr__(self): return representation.list_repr(self, indent=1) + # Include a backlink to the serializer class on return objects. + # Allows renderers such as HTMLFormRenderer to get the full field info. + + @property + def data(self): + ret = super(ListSerializer, self).data + return ReturnList(ret, serializer=self) + + @property + def errors(self): + ret = super(ListSerializer, self).errors + if isinstance(ret, dict): + return ReturnDict(ret, serializer=self) + return ReturnList(ret, serializer=self) + # ModelSerializer & HyperlinkedModelSerializer # -------------------------------------------- @@ -486,6 +573,14 @@ class ModelSerializer(Serializer): * A set of default fields are automatically populated. * A set of default validators are automatically populated. * Default `.create()` and `.update()` implementations are provided. + + The process of automatically determining a set of serializer fields + based on the model fields is reasonably complex, but you almost certainly + don't need to dig into the implemention. + + If the `ModelSerializer` class *doesn't* generate the set of fields that + you need you should either declare the extra/differing fields explicitly on + the serializer class, or simply use a `Serializer` class. """ _field_mapping = ClassLookupDict({ models.AutoField: IntegerField, @@ -513,13 +608,33 @@ class ModelSerializer(Serializer): }) _related_class = PrimaryKeyRelatedField - def create(self, validated_attrs): + def create(self, validated_data): + """ + We have a bit of extra checking around this in order to provide + descriptive messages when something goes wrong, but this method is + essentially just: + + return ExampleModel.objects.create(**validated_data) + + If there are many to many fields present on the instance then they + cannot be set until the model is instantiated, in which case the + implementation is like so: + + example_relationship = validated_data.pop('example_relationship') + instance = ExampleModel.objects.create(**validated_data) + instance.example_relationship = example_relationship + return instance + + The default implementation also does not handle nested relationships. + If you want to support writable nested relationships you'll need + to write an explicit `.create()` method. + """ # Check that the user isn't trying to handle a writable nested field. # If we don't do this explicitly they'd likely get a confusing # error at the point of calling `Model.objects.create()`. assert not any( - isinstance(field, BaseSerializer) and not field.read_only - for field in self.fields.values() + isinstance(field, BaseSerializer) and (key in validated_attrs) + for key, field in self.fields.items() ), ( 'The `.create()` method does not suport nested writable fields ' 'by default. Write an explicit `.create()` method for serializer ' @@ -529,16 +644,33 @@ class ModelSerializer(Serializer): ModelClass = self.Meta.model - # Remove many-to-many relationships from validated_attrs. + # Remove many-to-many relationships from validated_data. # They are not valid arguments to the default `.create()` method, # as they require that the instance has already been saved. info = model_meta.get_field_info(ModelClass) many_to_many = {} for field_name, relation_info in info.relations.items(): - if relation_info.to_many and (field_name in validated_attrs): - many_to_many[field_name] = validated_attrs.pop(field_name) + if relation_info.to_many and (field_name in validated_data): + many_to_many[field_name] = validated_data.pop(field_name) - instance = ModelClass.objects.create(**validated_attrs) + try: + instance = ModelClass.objects.create(**validated_data) + except TypeError as exc: + msg = ( + 'Got a `TypeError` when calling `%s.objects.create()`. ' + 'This may be because you have a writable field on the ' + 'serializer class that is not a valid argument to ' + '`%s.objects.create()`. You may need to make the field ' + 'read-only, or override the %s.create() method to handle ' + 'this correctly.\nOriginal exception text was: %s.' % + ( + ModelClass.__name__, + ModelClass.__name__, + self.__class__.__name__, + exc + ) + ) + raise TypeError(msg) # Save many-to-many relationships after the instance is created. if many_to_many: @@ -547,10 +679,10 @@ class ModelSerializer(Serializer): return instance - def update(self, instance, validated_attrs): + def update(self, instance, validated_data): assert not any( - isinstance(field, BaseSerializer) and not field.read_only - for field in self.fields.values() + isinstance(field, BaseSerializer) and (key in validated_attrs) + for key, field in self.fields.items() ), ( 'The `.update()` method does not suport nested writable fields ' 'by default. Write an explicit `.update()` method for serializer ' @@ -558,20 +690,25 @@ class ModelSerializer(Serializer): (self.__class__.__module__, self.__class__.__name__) ) - for attr, value in validated_attrs.items(): + for attr, value in validated_data.items(): setattr(instance, attr, value) instance.save() return instance def get_validators(self): + # If the validators have been declared explicitly then use that. + validators = getattr(getattr(self, 'Meta', None), 'validators', None) + if validators is not None: + return validators + + # Determine the default set of validators. + validators = [] + model_class = self.Meta.model field_names = set([ field.source for field in self.fields.values() if (field.source != '*') and ('.' not in field.source) ]) - validators = getattr(getattr(self, 'Meta', None), 'validators', []) - model_class = self.Meta.model - # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. for parent_class in [model_class] + list(model_class._meta.parents.keys()): @@ -658,49 +795,62 @@ class ModelSerializer(Serializer): # Determine if we need any additional `HiddenField` or extra keyword # arguments to deal with `unique_for` dates that are required to # be in the input data in order to validate it. - unique_fields = {} + hidden_fields = {} + unique_constraint_names = set() + for model_field_name, field_name in model_field_mapping.items(): try: model_field = model._meta.get_field(model_field_name) except FieldDoesNotExist: continue - # Deal with each of the `unique_for_*` cases. - for date_field_name in ( + # Include each of the `unique_for_*` field names. + unique_constraint_names |= set([ model_field.unique_for_date, model_field.unique_for_month, model_field.unique_for_year - ): - if date_field_name is None: - continue - - # Get the model field that is refered too. - date_field = model._meta.get_field(date_field_name) - - if date_field.auto_now_add: - default = CreateOnlyDefault(timezone.now) - elif date_field.auto_now: - default = timezone.now - elif date_field.has_default(): - default = model_field.default - else: - default = empty - - if date_field_name in model_field_mapping: - # The corresponding date field is present in the serializer - if date_field_name not in extra_kwargs: - extra_kwargs[date_field_name] = {} - if default is empty: - if 'required' not in extra_kwargs[date_field_name]: - extra_kwargs[date_field_name]['required'] = True - else: - if 'default' not in extra_kwargs[date_field_name]: - extra_kwargs[date_field_name]['default'] = default + ]) + + unique_constraint_names -= set([None]) + + # Include each of the `unique_together` field names, + # so long as all the field names are included on the serializer. + for parent_class in [model] + list(model._meta.parents.keys()): + for unique_together_list in parent_class._meta.unique_together: + if set(fields).issuperset(set(unique_together_list)): + unique_constraint_names |= set(unique_together_list) + + # Now we have all the field names that have uniqueness constraints + # applied, we can add the extra 'required=...' or 'default=...' + # arguments that are appropriate to these fields, or add a `HiddenField` for it. + for unique_constraint_name in unique_constraint_names: + # Get the model field that is refered too. + unique_constraint_field = model._meta.get_field(unique_constraint_name) + + if getattr(unique_constraint_field, 'auto_now_add', None): + default = CreateOnlyDefault(timezone.now) + elif getattr(unique_constraint_field, 'auto_now', None): + default = timezone.now + elif unique_constraint_field.has_default(): + default = unique_constraint_field.default + else: + default = empty + + if unique_constraint_name in model_field_mapping: + # The corresponding field is present in the serializer + if unique_constraint_name not in extra_kwargs: + extra_kwargs[unique_constraint_name] = {} + if default is empty: + if 'required' not in extra_kwargs[unique_constraint_name]: + extra_kwargs[unique_constraint_name]['required'] = True else: - # The corresponding date field is not present in the, - # serializer. We have a default to use for the date, so - # add in a hidden field that populates it. - unique_fields[date_field_name] = HiddenField(default=default) + if 'default' not in extra_kwargs[unique_constraint_name]: + extra_kwargs[unique_constraint_name]['default'] = default + elif default is not empty: + # The corresponding field is not present in the, + # serializer. We have a default to use for it, so + # add in a hidden field that populates it. + hidden_fields[unique_constraint_name] = HiddenField(default=default) # Now determine the fields that should be included on the serializer. for field_name in fields: @@ -776,12 +926,16 @@ class ModelSerializer(Serializer): 'validators', 'queryset' ]: kwargs.pop(attr, None) + + if extras.get('default') and kwargs.get('required') is False: + kwargs.pop('required') + kwargs.update(extras) # Create the serializer field. ret[field_name] = field_cls(**kwargs) - for field_name, field in unique_fields.items(): + for field_name, field in hidden_fields.items(): ret[field_name] = field return ret |
