diff options
| -rw-r--r-- | rest_framework/serializers.py | 174 | 
1 files changed, 96 insertions, 78 deletions
| diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d4b0926e..5e9cbe36 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -888,89 +888,19 @@ class ModelSerializer(Serializer):          # Retrieve metadata about fields & relationships on the model class.          info = model_meta.get_field_info(model) -        fields = self.get_field_names(declared_fields, info) +        field_names = self.get_field_names(declared_fields, info)          extra_kwargs = self.get_extra_kwargs() -        # Determine the set of model fields, and the fields that they map to. -        # We actually only need this to deal with the slightly awkward case -        # of supporting `unique_for_date`/`unique_for_month`/`unique_for_year`. -        model_field_mapping = {} -        for field_name in fields: -            if field_name in declared_fields: -                field = declared_fields[field_name] -                source = field.source or field_name +        model_fields = self.get_model_fields(field_names, declared_fields, extra_kwargs) +        uniqueness_extra_kwargs, hidden_fields = self.get_uniqueness_field_options(field_names, model_fields) +        for key, value in uniqueness_extra_kwargs.items(): +            if key in extra_kwargs: +                extra_kwargs[key].update(value)              else: -                try: -                    source = extra_kwargs[field_name]['source'] -                except KeyError: -                    source = field_name -            # Model fields will always have a simple source mapping, -            # they can't be nested attribute lookups. -            if '.' not in source and source != '*': -                model_field_mapping[source] = field_name - -        # Determine if we need any additional `HiddenField` or extra keyword -        # arguments to deal with `unique_for` dates that are required to -        # be in the input data in order to validate it. -        hidden_fields = {} -        unique_constraint_names = set() - -        for model_field_name, field_name in model_field_mapping.items(): -            try: -                model_field = model._meta.get_field(model_field_name) -            except FieldDoesNotExist: -                continue - -            # Include each of the `unique_for_*` field names. -            unique_constraint_names |= set([ -                model_field.unique_for_date, -                model_field.unique_for_month, -                model_field.unique_for_year -            ]) - -        unique_constraint_names -= set([None]) - -        # Include each of the `unique_together` field names, -        # so long as all the field names are included on the serializer. -        for parent_class in [model] + list(model._meta.parents.keys()): -            for unique_together_list in parent_class._meta.unique_together: -                if set(fields).issuperset(set(unique_together_list)): -                    unique_constraint_names |= set(unique_together_list) - -        # Now we have all the field names that have uniqueness constraints -        # applied, we can add the extra 'required=...' or 'default=...' -        # arguments that are appropriate to these fields, or add a `HiddenField` for it. -        for unique_constraint_name in unique_constraint_names: -            # Get the model field that is referred too. -            unique_constraint_field = model._meta.get_field(unique_constraint_name) - -            if getattr(unique_constraint_field, 'auto_now_add', None): -                default = CreateOnlyDefault(timezone.now) -            elif getattr(unique_constraint_field, 'auto_now', None): -                default = timezone.now -            elif unique_constraint_field.has_default(): -                default = unique_constraint_field.default -            else: -                default = empty - -            if unique_constraint_name in model_field_mapping: -                # The corresponding field is present in the serializer -                if unique_constraint_name not in extra_kwargs: -                    extra_kwargs[unique_constraint_name] = {} -                if default is empty: -                    if 'required' not in extra_kwargs[unique_constraint_name]: -                        extra_kwargs[unique_constraint_name]['required'] = True -                else: -                    if 'default' not in extra_kwargs[unique_constraint_name]: -                        extra_kwargs[unique_constraint_name]['default'] = default -            elif default is not empty: -                # The corresponding field is not present in the, -                # serializer. We have a default to use for it, so -                # add in a hidden field that populates it. -                hidden_fields[unique_constraint_name] = HiddenField(default=default) +                extra_kwargs[key] = value          # Now determine the fields that should be included on the serializer. -        for field_name in fields: +        for field_name in field_names:              if field_name in declared_fields:                  # Field is explicitly declared on the class, use that.                  ret[field_name] = declared_fields[field_name] @@ -1046,6 +976,94 @@ class ModelSerializer(Serializer):          return ret +    def get_model_fields(self, field_names, declared_fields, extra_kwargs): +        # Returns all the model fields that are being mapped to by fields +        # on the serializer class. +        # Returned as a dict of 'model field name' -> 'model field' +        model = getattr(self.Meta, 'model') +        model_fields = {} + +        for field_name in field_names: +            if field_name in declared_fields: +                # If the field is declared on the serializer +                field = declared_fields[field_name] +                source = field.source or field_name +            else: +                try: +                    source = extra_kwargs[field_name]['source'] +                except KeyError: +                    source = field_name + +            if '.' in source or source == '*': +                # Model fields will always have a simple source mapping, +                # they can't be nested attribute lookups. +                continue + +            try: +                model_fields[source] = model._meta.get_field(source) +            except FieldDoesNotExist: +                pass + +        return model_fields + +    def get_uniqueness_field_options(self, field_names, model_fields): +        model = getattr(self.Meta, 'model') + +        # Determine if we need any additional `HiddenField` or extra keyword +        # arguments to deal with `unique_for` dates that are required to +        # be in the input data in order to validate it. +        unique_constraint_names = set() + +        for model_field in model_fields.values(): +            # Include each of the `unique_for_*` field names. +            unique_constraint_names |= set([ +                model_field.unique_for_date, +                model_field.unique_for_month, +                model_field.unique_for_year +            ]) + +        unique_constraint_names -= set([None]) + +        # Include each of the `unique_together` field names, +        # so long as all the field names are included on the serializer. +        for parent_class in [model] + list(model._meta.parents.keys()): +            for unique_together_list in parent_class._meta.unique_together: +                if set(field_names).issuperset(set(unique_together_list)): +                    unique_constraint_names |= set(unique_together_list) + +        # Now we have all the field names that have uniqueness constraints +        # applied, we can add the extra 'required=...' or 'default=...' +        # arguments that are appropriate to these fields, or add a `HiddenField` for it. +        hidden_fields = {} +        extra_kwargs = {} + +        for unique_constraint_name in unique_constraint_names: +            # Get the model field that is referred too. +            unique_constraint_field = model._meta.get_field(unique_constraint_name) + +            if getattr(unique_constraint_field, 'auto_now_add', None): +                default = CreateOnlyDefault(timezone.now) +            elif getattr(unique_constraint_field, 'auto_now', None): +                default = timezone.now +            elif unique_constraint_field.has_default(): +                default = unique_constraint_field.default +            else: +                default = empty + +            if unique_constraint_name in model_fields: +                # The corresponding field is present in the serializer +                if default is empty: +                    extra_kwargs[unique_constraint_name] = {'required': True} +                else: +                    extra_kwargs[unique_constraint_name] = {'default': default} +            elif default is not empty: +                # The corresponding field is not present in the, +                # serializer. We have a default to use for it, so +                # add in a hidden field that populates it. +                hidden_fields[unique_constraint_name] = HiddenField(default=default) + +        return extra_kwargs, hidden_fields +      def get_extra_kwargs(self):          """          Return a dictionary mapping field names to a dictionary of | 
