diff options
Diffstat (limited to 'rest_framework/validators.py')
| -rw-r--r-- | rest_framework/validators.py | 95 |
1 files changed, 79 insertions, 16 deletions
diff --git a/rest_framework/validators.py b/rest_framework/validators.py index f3773f17..d7f847aa 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -25,6 +25,10 @@ class UniqueValidator: self.message = message or self.message def set_context(self, serializer_field): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ # Determine the underlying model field name. This may not be the # same as the serializer field name if `source=<>` is set. self.field_name = serializer_field.source_attrs[0] @@ -54,6 +58,7 @@ class UniqueTogetherValidator: Should be applied to the serializer class, not to an individual field. """ message = _('The fields {field_names} must make a unique set.') + missing_message = _('This field is required.') def __init__(self, queryset, fields, message=None): self.queryset = queryset @@ -62,17 +67,49 @@ class UniqueTogetherValidator: self.message = message or self.message def set_context(self, serializer): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ # Determine the existing instance, if this is an update operation. self.instance = getattr(serializer, 'instance', None) - def __call__(self, attrs): - # Ensure uniqueness. + def enforce_required_fields(self, attrs): + """ + The `UniqueTogetherValidator` always forces an implied 'required' + state on the fields it applies to. + """ + missing = dict([ + (field_name, self.missing_message) + for field_name in self.fields + if field_name not in attrs + ]) + if missing: + raise ValidationError(missing) + + def filter_queryset(self, attrs, queryset): + """ + Filter the queryset to all instances matching the given attributes. + """ filter_kwargs = dict([ (field_name, attrs[field_name]) for field_name in self.fields ]) - queryset = self.queryset.filter(**filter_kwargs) + return queryset.filter(**filter_kwargs) + + def exclude_current_instance(self, attrs, queryset): + """ + If an instance is being updated, then do not include + that instance itself as a uniqueness conflict. + """ if self.instance is not None: - queryset = queryset.exclude(pk=self.instance.pk) + return queryset.exclude(pk=self.instance.pk) + return queryset + + def __call__(self, attrs): + self.enforce_required_fields(attrs) + queryset = self.queryset + queryset = self.filter_queryset(attrs, queryset) + queryset = self.exclude_current_instance(attrs, queryset) if queryset.exists(): field_names = ', '.join(self.fields) raise ValidationError(self.message.format(field_names=field_names)) @@ -87,6 +124,7 @@ class UniqueTogetherValidator: class BaseUniqueForValidator: message = None + missing_message = _('This field is required.') def __init__(self, queryset, field, date_field, message=None): self.queryset = queryset @@ -95,6 +133,10 @@ class BaseUniqueForValidator: self.message = message or self.message def set_context(self, serializer): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ # Determine the underlying model field names. These may not be the # same as the serializer field names if `source=<>` is set. self.field_name = serializer.fields[self.field].source_attrs[0] @@ -102,15 +144,36 @@ class BaseUniqueForValidator: # Determine the existing instance, if this is an update operation. self.instance = getattr(serializer, 'instance', None) - def get_filter_kwargs(self, attrs): - raise NotImplementedError('`get_filter_kwargs` must be implemented.') + def enforce_required_fields(self, attrs): + """ + The `UniqueFor<Range>Validator` classes always force an implied + 'required' state on the fields they are applied to. + """ + missing = dict([ + (field_name, self.missing_message) + for field_name in [self.field, self.date_field] + if field_name not in attrs + ]) + if missing: + raise ValidationError(missing) - def __call__(self, attrs): - filter_kwargs = self.get_filter_kwargs(attrs) + def filter_queryset(self, attrs, queryset): + raise NotImplementedError('`filter_queryset` must be implemented.') - queryset = self.queryset.filter(**filter_kwargs) + def exclude_current_instance(self, attrs, queryset): + """ + If an instance is being updated, then do not include + that instance itself as a uniqueness conflict. + """ if self.instance is not None: - queryset = queryset.exclude(pk=self.instance.pk) + return queryset.exclude(pk=self.instance.pk) + return queryset + + def __call__(self, attrs): + self.enforce_required_fields(attrs) + queryset = self.queryset + queryset = self.filter_queryset(attrs, queryset) + queryset = self.exclude_current_instance(attrs, queryset) if queryset.exists(): message = self.message.format(date_field=self.date_field) raise ValidationError({self.field: message}) @@ -127,7 +190,7 @@ class BaseUniqueForValidator: class UniqueForDateValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" date.') - def get_filter_kwargs(self, attrs): + def filter_queryset(self, attrs, queryset): value = attrs[self.field] date = attrs[self.date_field] @@ -136,30 +199,30 @@ class UniqueForDateValidator(BaseUniqueForValidator): filter_kwargs['%s__day' % self.date_field_name] = date.day filter_kwargs['%s__month' % self.date_field_name] = date.month filter_kwargs['%s__year' % self.date_field_name] = date.year - return filter_kwargs + return queryset.filter(**filter_kwargs) class UniqueForMonthValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" month.') - def get_filter_kwargs(self, attrs): + def filter_queryset(self, attrs, queryset): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} filter_kwargs[self.field_name] = value filter_kwargs['%s__month' % self.date_field_name] = date.month - return filter_kwargs + return queryset.filter(**filter_kwargs) class UniqueForYearValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" year.') - def get_filter_kwargs(self, attrs): + def filter_queryset(self, attrs, queryset): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} filter_kwargs[self.field_name] = value filter_kwargs['%s__year' % self.date_field_name] = date.year - return filter_kwargs + return queryset.filter(**filter_kwargs) |
