diff options
| author | Tom Christie | 2014-11-19 13:55:10 +0000 | 
|---|---|---|
| committer | Tom Christie | 2014-11-19 13:55:10 +0000 | 
| commit | 8586290df80ac8448d71cdb3326bc822c399cad1 (patch) | |
| tree | 25b657c1bf29547c9380816d5fc8b4f42854ceba /rest_framework/serializers.py | |
| parent | 6cb6510132b319c96b28bea732032aaf2d495895 (diff) | |
| download | django-rest-framework-8586290df80ac8448d71cdb3326bc822c399cad1.tar.bz2 | |
Apply defaults and requiredness to unique_together fields. Closes #2092.
Diffstat (limited to 'rest_framework/serializers.py')
| -rw-r--r-- | rest_framework/serializers.py | 83 | 
1 files changed, 49 insertions, 34 deletions
| diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 84282cdb..2e34dbe7 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -720,49 +720,60 @@ class ModelSerializer(Serializer):          # Determine if we need any additional `HiddenField` or extra keyword          # arguments to deal with `unique_for` dates that are required to          # be in the input data in order to validate it. -        unique_fields = {} +        hidden_fields = {} +          for model_field_name, field_name in model_field_mapping.items():              try:                  model_field = model._meta.get_field(model_field_name)              except FieldDoesNotExist:                  continue -            # Deal with each of the `unique_for_*` cases. -            for date_field_name in ( +            # Include each of the `unique_for_*` field names. +            unique_constraint_names = set([                  model_field.unique_for_date,                  model_field.unique_for_month,                  model_field.unique_for_year -            ): -                if date_field_name is None: -                    continue - -                # Get the model field that is refered too. -                date_field = model._meta.get_field(date_field_name) - -                if date_field.auto_now_add: -                    default = CreateOnlyDefault(timezone.now) -                elif date_field.auto_now: -                    default = timezone.now -                elif date_field.has_default(): -                    default = model_field.default -                else: -                    default = empty - -                if date_field_name in model_field_mapping: -                    # The corresponding date field is present in the serializer -                    if date_field_name not in extra_kwargs: -                        extra_kwargs[date_field_name] = {} -                    if default is empty: -                        if 'required' not in extra_kwargs[date_field_name]: -                            extra_kwargs[date_field_name]['required'] = True -                    else: -                        if 'default' not in extra_kwargs[date_field_name]: -                            extra_kwargs[date_field_name]['default'] = default +            ]) +            unique_constraint_names -= set([None]) + +            # Include each of the `unique_together` field names, +            # so long as all the field names are included on the serializer. +            for parent_class in [model] + list(model._meta.parents.keys()): +                for unique_together_list in parent_class._meta.unique_together: +                    if set(fields).issuperset(set(unique_together_list)): +                        unique_constraint_names |= set(unique_together_list) + +        # Now we have all the field names that have uniqueness constraints +        # applied, we can add the extra 'required=...' or 'default=...' +        # arguments that are appropriate to these fields, or add a `HiddenField` for it. +        for unique_constraint_name in unique_constraint_names: +            # Get the model field that is refered too. +            unique_constraint_field = model._meta.get_field(unique_constraint_name) + +            if getattr(unique_constraint_field, 'auto_now_add', None): +                default = CreateOnlyDefault(timezone.now) +            elif getattr(unique_constraint_field, 'auto_now', None): +                default = timezone.now +            elif unique_constraint_field.has_default(): +                default = model_field.default +            else: +                default = empty + +            if unique_constraint_name in model_field_mapping: +                # The corresponding field is present in the serializer +                if unique_constraint_name not in extra_kwargs: +                    extra_kwargs[unique_constraint_name] = {} +                if default is empty: +                    if 'required' not in extra_kwargs[unique_constraint_name]: +                        extra_kwargs[unique_constraint_name]['required'] = True                  else: -                    # The corresponding date field is not present in the, -                    # serializer. We have a default to use for the date, so -                    # add in a hidden field that populates it. -                    unique_fields[date_field_name] = HiddenField(default=default) +                    if 'default' not in extra_kwargs[unique_constraint_name]: +                        extra_kwargs[unique_constraint_name]['default'] = default +            elif default is not empty: +                # The corresponding field is not present in the, +                # serializer. We have a default to use for it, so +                # add in a hidden field that populates it. +                hidden_fields[unique_constraint_name] = HiddenField(default=default)          # Now determine the fields that should be included on the serializer.          for field_name in fields: @@ -838,12 +849,16 @@ class ModelSerializer(Serializer):                      'validators', 'queryset'                  ]:                      kwargs.pop(attr, None) + +            if extras.get('default') and kwargs.get('required') is False: +                kwargs.pop('required') +              kwargs.update(extras)              # Create the serializer field.              ret[field_name] = field_cls(**kwargs) -        for field_name, field in unique_fields.items(): +        for field_name, field in hidden_fields.items():              ret[field_name] = field          return ret | 
