diff options
| -rw-r--r-- | docs/api-guide/validators.md | 10 | ||||
| -rw-r--r-- | rest_framework/validators.py | 95 | ||||
| -rw-r--r-- | tests/test_validators.py | 11 | 
3 files changed, 100 insertions, 16 deletions
| diff --git a/docs/api-guide/validators.md b/docs/api-guide/validators.md index 6a0ef4ff..bb073f57 100644 --- a/docs/api-guide/validators.md +++ b/docs/api-guide/validators.md @@ -93,6 +93,12 @@ The validator should be applied to *serializer classes*, like so:                  )              ] +--- + +**Note**: The `UniqueTogetherValidation` class always imposes an implicit constraint that all the fields it applies to are always treated as required. Fields with `default` values are an exception to this as they always supply a value even when omitted from user input. + +--- +  ## UniqueForDateValidator  ## UniqueForMonthValidator @@ -146,6 +152,10 @@ If you want the date field to be entirely hidden from the user, then use `Hidden  --- +**Note**: The `UniqueFor<Range>Validation` classes always imposes an implicit constraint that the fields they are applied to are always treated as required. Fields with `default` values are an exception to this as they always supply a value even when omitted from user input. + +--- +  # Writing custom validators  You can use any of Django's existing validators, or write your own custom validators. diff --git a/rest_framework/validators.py b/rest_framework/validators.py index f3773f17..d7f847aa 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -25,6 +25,10 @@ class UniqueValidator:          self.message = message or self.message      def set_context(self, serializer_field): +        """ +        This hook is called by the serializer instance, +        prior to the validation call being made. +        """          # Determine the underlying model field name. This may not be the          # same as the serializer field name if `source=<>` is set.          self.field_name = serializer_field.source_attrs[0] @@ -54,6 +58,7 @@ class UniqueTogetherValidator:      Should be applied to the serializer class, not to an individual field.      """      message = _('The fields {field_names} must make a unique set.') +    missing_message = _('This field is required.')      def __init__(self, queryset, fields, message=None):          self.queryset = queryset @@ -62,17 +67,49 @@ class UniqueTogetherValidator:          self.message = message or self.message      def set_context(self, serializer): +        """ +        This hook is called by the serializer instance, +        prior to the validation call being made. +        """          # Determine the existing instance, if this is an update operation.          self.instance = getattr(serializer, 'instance', None) -    def __call__(self, attrs): -        # Ensure uniqueness. +    def enforce_required_fields(self, attrs): +        """ +        The `UniqueTogetherValidator` always forces an implied 'required' +        state on the fields it applies to. +        """ +        missing = dict([ +            (field_name, self.missing_message) +            for field_name in self.fields +            if field_name not in attrs +        ]) +        if missing: +            raise ValidationError(missing) + +    def filter_queryset(self, attrs, queryset): +        """ +        Filter the queryset to all instances matching the given attributes. +        """          filter_kwargs = dict([              (field_name, attrs[field_name]) for field_name in self.fields          ]) -        queryset = self.queryset.filter(**filter_kwargs) +        return queryset.filter(**filter_kwargs) + +    def exclude_current_instance(self, attrs, queryset): +        """ +        If an instance is being updated, then do not include +        that instance itself as a uniqueness conflict. +        """          if self.instance is not None: -            queryset = queryset.exclude(pk=self.instance.pk) +            return queryset.exclude(pk=self.instance.pk) +        return queryset + +    def __call__(self, attrs): +        self.enforce_required_fields(attrs) +        queryset = self.queryset +        queryset = self.filter_queryset(attrs, queryset) +        queryset = self.exclude_current_instance(attrs, queryset)          if queryset.exists():              field_names = ', '.join(self.fields)              raise ValidationError(self.message.format(field_names=field_names)) @@ -87,6 +124,7 @@ class UniqueTogetherValidator:  class BaseUniqueForValidator:      message = None +    missing_message = _('This field is required.')      def __init__(self, queryset, field, date_field, message=None):          self.queryset = queryset @@ -95,6 +133,10 @@ class BaseUniqueForValidator:          self.message = message or self.message      def set_context(self, serializer): +        """ +        This hook is called by the serializer instance, +        prior to the validation call being made. +        """          # Determine the underlying model field names. These may not be the          # same as the serializer field names if `source=<>` is set.          self.field_name = serializer.fields[self.field].source_attrs[0] @@ -102,15 +144,36 @@ class BaseUniqueForValidator:          # Determine the existing instance, if this is an update operation.          self.instance = getattr(serializer, 'instance', None) -    def get_filter_kwargs(self, attrs): -        raise NotImplementedError('`get_filter_kwargs` must be implemented.') +    def enforce_required_fields(self, attrs): +        """ +        The `UniqueFor<Range>Validator` classes always force an implied +        'required' state on the fields they are applied to. +        """ +        missing = dict([ +            (field_name, self.missing_message) +            for field_name in [self.field, self.date_field] +            if field_name not in attrs +        ]) +        if missing: +            raise ValidationError(missing) -    def __call__(self, attrs): -        filter_kwargs = self.get_filter_kwargs(attrs) +    def filter_queryset(self, attrs, queryset): +        raise NotImplementedError('`filter_queryset` must be implemented.') -        queryset = self.queryset.filter(**filter_kwargs) +    def exclude_current_instance(self, attrs, queryset): +        """ +        If an instance is being updated, then do not include +        that instance itself as a uniqueness conflict. +        """          if self.instance is not None: -            queryset = queryset.exclude(pk=self.instance.pk) +            return queryset.exclude(pk=self.instance.pk) +        return queryset + +    def __call__(self, attrs): +        self.enforce_required_fields(attrs) +        queryset = self.queryset +        queryset = self.filter_queryset(attrs, queryset) +        queryset = self.exclude_current_instance(attrs, queryset)          if queryset.exists():              message = self.message.format(date_field=self.date_field)              raise ValidationError({self.field: message}) @@ -127,7 +190,7 @@ class BaseUniqueForValidator:  class UniqueForDateValidator(BaseUniqueForValidator):      message = _('This field must be unique for the "{date_field}" date.') -    def get_filter_kwargs(self, attrs): +    def filter_queryset(self, attrs, queryset):          value = attrs[self.field]          date = attrs[self.date_field] @@ -136,30 +199,30 @@ class UniqueForDateValidator(BaseUniqueForValidator):          filter_kwargs['%s__day' % self.date_field_name] = date.day          filter_kwargs['%s__month' % self.date_field_name] = date.month          filter_kwargs['%s__year' % self.date_field_name] = date.year -        return filter_kwargs +        return queryset.filter(**filter_kwargs)  class UniqueForMonthValidator(BaseUniqueForValidator):      message = _('This field must be unique for the "{date_field}" month.') -    def get_filter_kwargs(self, attrs): +    def filter_queryset(self, attrs, queryset):          value = attrs[self.field]          date = attrs[self.date_field]          filter_kwargs = {}          filter_kwargs[self.field_name] = value          filter_kwargs['%s__month' % self.date_field_name] = date.month -        return filter_kwargs +        return queryset.filter(**filter_kwargs)  class UniqueForYearValidator(BaseUniqueForValidator):      message = _('This field must be unique for the "{date_field}" year.') -    def get_filter_kwargs(self, attrs): +    def filter_queryset(self, attrs, queryset):          value = attrs[self.field]          date = attrs[self.date_field]          filter_kwargs = {}          filter_kwargs[self.field_name] = value          filter_kwargs['%s__year' % self.date_field_name] = date.year -        return filter_kwargs +        return queryset.filter(**filter_kwargs) diff --git a/tests/test_validators.py b/tests/test_validators.py index e6e0b23a..86614b10 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -134,6 +134,17 @@ class TestUniquenessTogetherValidation(TestCase):              'position': 1          } +    def test_unique_together_is_required(self): +        """ +        In a unique together validation, all fields are required. +        """ +        data = {'position': 2} +        serializer = UniquenessTogetherSerializer(data=data, partial=True) +        assert not serializer.is_valid() +        assert serializer.errors == { +            'race_name': ['This field is required.'] +        } +      def test_ignore_excluded_fields(self):          """          When model fields are not included in a serializer, then uniqueness | 
