diff options
| author | Tom Christie | 2014-10-28 16:21:49 +0000 | 
|---|---|---|
| committer | Tom Christie | 2014-10-28 16:21:49 +0000 | 
| commit | 9ebaabd6eb31e18cf0bb1c70893f719f18ecb0f9 (patch) | |
| tree | b3718c7e155e3d3c97666ce2cfd832e0a8381de0 /rest_framework | |
| parent | 702f47700de2c10f26f06b23099740c408ffe797 (diff) | |
| download | django-rest-framework-9ebaabd6eb31e18cf0bb1c70893f719f18ecb0f9.tar.bz2 | |
unique_for_date/unique_for_month/unique_for_year
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/fields.py | 54 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 82 | 
2 files changed, 129 insertions, 7 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index e939b2f2..82b7eb37 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -92,13 +92,35 @@ def set_value(dictionary, keys, value):      dictionary[keys[-1]] = value +class CreateOnlyDefault: +    """ +    This class may be used to provide default values that are only used +    for create operations, but that do not return any value for update +    operations. +    """ +    def __init__(self, default): +        self.default = default + +    def set_context(self, serializer_field): +        self.is_update = serializer_field.parent.instance is not None + +    def __call__(self): +        if self.is_update: +            raise SkipField() +        if callable(self.default): +            return self.default() +        return self.default + +    def __repr__(self): +        return '%s(%s)' % (self.__class__.__name__, repr(self.default)) + +  class SkipField(Exception):      pass  NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'  NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' -NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'  NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'  USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'  MISSING_ERROR_MESSAGE = ( @@ -132,7 +154,6 @@ class Field(object):          # Some combinations of keyword arguments do not make sense.          assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY          assert not (read_only and required), NOT_READ_ONLY_REQUIRED -        assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT          assert not (required and default is not empty), NOT_REQUIRED_DEFAULT          assert not (read_only and self.__class__ == Field), USE_READONLYFIELD @@ -230,7 +251,9 @@ class Field(object):          """          if self.default is empty:              raise SkipField() -        if is_simple_callable(self.default): +        if callable(self.default): +            if hasattr(self.default, 'set_context'): +                self.default.set_context(self)              return self.default()          return self.default @@ -244,6 +267,9 @@ class Field(object):          May raise `SkipField` if the field should not be included in the          validated data.          """ +        if self.read_only: +            return self.get_default() +          if data is empty:              if getattr(self.root, 'partial', False):                  raise SkipField() @@ -1033,6 +1059,28 @@ class ReadOnlyField(Field):          return value +class HiddenField(Field): +    """ +    A hidden field does not take input from the user, or present any output, +    but it does populate a field in `validated_data`, based on its default +    value. This is particularly useful when we have a `unique_for_date` +    constrain on a pair of fields, as we need some way to include the date in +    the validated data. +    """ +    def __init__(self, **kwargs): +        assert 'default' in kwargs, 'default is a required argument.' +        kwargs['write_only'] = True +        super(HiddenField, self).__init__(**kwargs) + +    def get_value(self, dictionary): +        # We always use the default value for `HiddenField`. +        # User input is never provided or accepted. +        return empty + +    def to_internal_value(self, data): +        return data + +  class SerializerMethodField(Field):      """      A read-only field that get its representation from calling a method on the diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b45f343a..6aab020e 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -12,6 +12,7 @@ response content is handled by parsers and renderers.  """  from django.core.exceptions import ImproperlyConfigured  from django.db import models +from django.db.models.fields import FieldDoesNotExist  from django.utils import six  from django.utils.datastructures import SortedDict  from rest_framework.exceptions import ValidationError @@ -368,7 +369,10 @@ class Serializer(BaseSerializer):          """          ret = {}          errors = ReturnDict(serializer=self) -        fields = [field for field in self.fields.values() if not field.read_only] +        fields = [ +            field for field in self.fields.values() +            if (not field.read_only) or (field.default is not empty) +        ]          for field in fields:              validate_method = getattr(self, 'validate_' + field.field_name, None) @@ -517,7 +521,7 @@ class ModelSerializer(Serializer):      def __init__(self, *args, **kwargs):          super(ModelSerializer, self).__init__(*args, **kwargs)          if 'validators' not in kwargs: -            validators = self.get_unique_together_validators() +            validators = self.get_default_validators()              if validators:                  self.validators.extend(validators)                  self._kwargs['validators'] = validators @@ -572,7 +576,7 @@ class ModelSerializer(Serializer):          instance.save()          return instance -    def get_unique_together_validators(self): +    def get_default_validators(self):          field_names = set([              field.source for field in self.fields.values()              if (field.source != '*') and ('.' not in field.source) @@ -592,6 +596,7 @@ class ModelSerializer(Serializer):                      )                      validators.append(validator) +        # Add any unique_for_date/unique_for_month/unique_for_year constraints.          info = model_meta.get_field_info(model_class)          for field_name, field in info.fields_and_pk.items():              if field.unique_for_date and field_name in field_names: @@ -637,7 +642,7 @@ class ModelSerializer(Serializer):          # Retrieve metadata about fields & relationships on the model class.          info = model_meta.get_field_info(model) -        # Use the default set of fields if none is supplied explicitly. +        # Use the default set of field names if none is supplied explicitly.          if fields is None:              fields = self._get_default_field_names(declared_fields, info)              exclude = getattr(self.Meta, 'exclude', None) @@ -645,6 +650,72 @@ class ModelSerializer(Serializer):                  for field_name in exclude:                      fields.remove(field_name) +        # Determine the set of model fields, and the fields that they map to. +        # We actually only need this to deal with the slightly awkward case +        # of supporting `unique_for_date`/`unique_for_month`/`unique_for_year`. +        model_field_mapping = {} +        for field_name in fields: +            if field_name in declared_fields: +                field = declared_fields[field_name] +                source = field.source or field_name +            else: +                try: +                    source = extra_kwargs[field_name]['source'] +                except KeyError: +                    source = field_name +            # Model fields will always have a simple source mapping, +            # they can't be nested attribute lookups. +            if '.' not in source and source != '*': +                model_field_mapping[source] = field_name + +        # Determine if we need any additional `HiddenField` or extra keyword +        # arguments to deal with `unique_for` dates that are required to +        # be in the input data in order to validate it. +        unique_fields = {} +        for model_field_name, field_name in model_field_mapping.items(): +            try: +                model_field = model._meta.get_field(model_field_name) +            except FieldDoesNotExist: +                continue + +            # Deal with each of the `unique_for_*` cases. +            for date_field_name in ( +                model_field.unique_for_date, +                model_field.unique_for_month, +                model_field.unique_for_year +            ): +                if date_field_name is None: +                    continue + +                # Get the model field that is refered too. +                date_field = model._meta.get_field(date_field_name) + +                if date_field.auto_now_add: +                    default = CreateOnlyDefault(timezone.now) +                elif date_field.auto_now: +                    default = timezone.now +                elif date_field.has_default(): +                    default = model_field.default +                else: +                    default = empty + +                if date_field_name in model_field_mapping: +                    # The corresponding date field is present in the serializer +                    if date_field_name not in extra_kwargs: +                        extra_kwargs[date_field_name] = {} +                    if default is empty: +                        if 'required' not in extra_kwargs[date_field_name]: +                            extra_kwargs[date_field_name]['required'] = True +                    else: +                        if 'default' not in extra_kwargs[date_field_name]: +                            extra_kwargs[date_field_name]['default'] = default +                else: +                    # The corresponding date field is not present in the, +                    # serializer. We have a default to use for the date, so +                    # add in a hidden field that populates it. +                    unique_fields[date_field_name] = HiddenField(default=default) + +        # Now determine the fields that should be included on the serializer.          for field_name in fields:              if field_name in declared_fields:                  # Field is explicitly declared on the class, use that. @@ -723,6 +794,9 @@ class ModelSerializer(Serializer):              # Create the serializer field.              ret[field_name] = field_cls(**kwargs) +        for field_name, field in unique_fields.items(): +            ret[field_name] = field +          return ret      def _include_additional_options(self, extra_kwargs):  | 
