diff options
| author | Tom Christie | 2014-11-10 12:21:27 +0000 | 
|---|---|---|
| committer | Tom Christie | 2014-11-10 12:21:27 +0000 | 
| commit | f387cd89da55ef88fcac504f5795ea9b591f3fba (patch) | |
| tree | 3c927d172e0ea18c65f6afd2360c308e286c13a3 /rest_framework/validators.py | |
| parent | 93633c297c69a1eefda5e153553c4f021cf10bd8 (diff) | |
| download | django-rest-framework-f387cd89da55ef88fcac504f5795ea9b591f3fba.tar.bz2 | |
Uniqueness constraints imply a forced 'required=True'. Refs #1945
Diffstat (limited to 'rest_framework/validators.py')
| -rw-r--r-- | rest_framework/validators.py | 95 | 
1 files changed, 79 insertions, 16 deletions
| diff --git a/rest_framework/validators.py b/rest_framework/validators.py index f3773f17..d7f847aa 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -25,6 +25,10 @@ class UniqueValidator:          self.message = message or self.message      def set_context(self, serializer_field): +        """ +        This hook is called by the serializer instance, +        prior to the validation call being made. +        """          # Determine the underlying model field name. This may not be the          # same as the serializer field name if `source=<>` is set.          self.field_name = serializer_field.source_attrs[0] @@ -54,6 +58,7 @@ class UniqueTogetherValidator:      Should be applied to the serializer class, not to an individual field.      """      message = _('The fields {field_names} must make a unique set.') +    missing_message = _('This field is required.')      def __init__(self, queryset, fields, message=None):          self.queryset = queryset @@ -62,17 +67,49 @@ class UniqueTogetherValidator:          self.message = message or self.message      def set_context(self, serializer): +        """ +        This hook is called by the serializer instance, +        prior to the validation call being made. +        """          # Determine the existing instance, if this is an update operation.          self.instance = getattr(serializer, 'instance', None) -    def __call__(self, attrs): -        # Ensure uniqueness. +    def enforce_required_fields(self, attrs): +        """ +        The `UniqueTogetherValidator` always forces an implied 'required' +        state on the fields it applies to. +        """ +        missing = dict([ +            (field_name, self.missing_message) +            for field_name in self.fields +            if field_name not in attrs +        ]) +        if missing: +            raise ValidationError(missing) + +    def filter_queryset(self, attrs, queryset): +        """ +        Filter the queryset to all instances matching the given attributes. +        """          filter_kwargs = dict([              (field_name, attrs[field_name]) for field_name in self.fields          ]) -        queryset = self.queryset.filter(**filter_kwargs) +        return queryset.filter(**filter_kwargs) + +    def exclude_current_instance(self, attrs, queryset): +        """ +        If an instance is being updated, then do not include +        that instance itself as a uniqueness conflict. +        """          if self.instance is not None: -            queryset = queryset.exclude(pk=self.instance.pk) +            return queryset.exclude(pk=self.instance.pk) +        return queryset + +    def __call__(self, attrs): +        self.enforce_required_fields(attrs) +        queryset = self.queryset +        queryset = self.filter_queryset(attrs, queryset) +        queryset = self.exclude_current_instance(attrs, queryset)          if queryset.exists():              field_names = ', '.join(self.fields)              raise ValidationError(self.message.format(field_names=field_names)) @@ -87,6 +124,7 @@ class UniqueTogetherValidator:  class BaseUniqueForValidator:      message = None +    missing_message = _('This field is required.')      def __init__(self, queryset, field, date_field, message=None):          self.queryset = queryset @@ -95,6 +133,10 @@ class BaseUniqueForValidator:          self.message = message or self.message      def set_context(self, serializer): +        """ +        This hook is called by the serializer instance, +        prior to the validation call being made. +        """          # Determine the underlying model field names. These may not be the          # same as the serializer field names if `source=<>` is set.          self.field_name = serializer.fields[self.field].source_attrs[0] @@ -102,15 +144,36 @@ class BaseUniqueForValidator:          # Determine the existing instance, if this is an update operation.          self.instance = getattr(serializer, 'instance', None) -    def get_filter_kwargs(self, attrs): -        raise NotImplementedError('`get_filter_kwargs` must be implemented.') +    def enforce_required_fields(self, attrs): +        """ +        The `UniqueFor<Range>Validator` classes always force an implied +        'required' state on the fields they are applied to. +        """ +        missing = dict([ +            (field_name, self.missing_message) +            for field_name in [self.field, self.date_field] +            if field_name not in attrs +        ]) +        if missing: +            raise ValidationError(missing) -    def __call__(self, attrs): -        filter_kwargs = self.get_filter_kwargs(attrs) +    def filter_queryset(self, attrs, queryset): +        raise NotImplementedError('`filter_queryset` must be implemented.') -        queryset = self.queryset.filter(**filter_kwargs) +    def exclude_current_instance(self, attrs, queryset): +        """ +        If an instance is being updated, then do not include +        that instance itself as a uniqueness conflict. +        """          if self.instance is not None: -            queryset = queryset.exclude(pk=self.instance.pk) +            return queryset.exclude(pk=self.instance.pk) +        return queryset + +    def __call__(self, attrs): +        self.enforce_required_fields(attrs) +        queryset = self.queryset +        queryset = self.filter_queryset(attrs, queryset) +        queryset = self.exclude_current_instance(attrs, queryset)          if queryset.exists():              message = self.message.format(date_field=self.date_field)              raise ValidationError({self.field: message}) @@ -127,7 +190,7 @@ class BaseUniqueForValidator:  class UniqueForDateValidator(BaseUniqueForValidator):      message = _('This field must be unique for the "{date_field}" date.') -    def get_filter_kwargs(self, attrs): +    def filter_queryset(self, attrs, queryset):          value = attrs[self.field]          date = attrs[self.date_field] @@ -136,30 +199,30 @@ class UniqueForDateValidator(BaseUniqueForValidator):          filter_kwargs['%s__day' % self.date_field_name] = date.day          filter_kwargs['%s__month' % self.date_field_name] = date.month          filter_kwargs['%s__year' % self.date_field_name] = date.year -        return filter_kwargs +        return queryset.filter(**filter_kwargs)  class UniqueForMonthValidator(BaseUniqueForValidator):      message = _('This field must be unique for the "{date_field}" month.') -    def get_filter_kwargs(self, attrs): +    def filter_queryset(self, attrs, queryset):          value = attrs[self.field]          date = attrs[self.date_field]          filter_kwargs = {}          filter_kwargs[self.field_name] = value          filter_kwargs['%s__month' % self.date_field_name] = date.month -        return filter_kwargs +        return queryset.filter(**filter_kwargs)  class UniqueForYearValidator(BaseUniqueForValidator):      message = _('This field must be unique for the "{date_field}" year.') -    def get_filter_kwargs(self, attrs): +    def filter_queryset(self, attrs, queryset):          value = attrs[self.field]          date = attrs[self.date_field]          filter_kwargs = {}          filter_kwargs[self.field_name] = value          filter_kwargs['%s__year' % self.date_field_name] = date.year -        return filter_kwargs +        return queryset.filter(**filter_kwargs) | 
