aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/validators.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/validators.py')
-rw-r--r--rest_framework/validators.py254
1 files changed, 254 insertions, 0 deletions
diff --git a/rest_framework/validators.py b/rest_framework/validators.py
new file mode 100644
index 00000000..7ca4e6a9
--- /dev/null
+++ b/rest_framework/validators.py
@@ -0,0 +1,254 @@
+"""
+We perform uniqueness checks explicitly on the serializer class, rather
+the using Django's `.full_clean()`.
+
+This gives us better separation of concerns, allows us to use single-step
+object creation, and makes it possible to switch between using the implicit
+`ModelSerializer` class and an equivelent explicit `Serializer` class.
+"""
+from django.utils.translation import ugettext_lazy as _
+from rest_framework.exceptions import ValidationError
+from rest_framework.utils.representation import smart_repr
+
+
+class UniqueValidator:
+ """
+ Validator that corresponds to `unique=True` on a model field.
+
+ Should be applied to an individual field on the serializer.
+ """
+ message = _('This field must be unique.')
+
+ def __init__(self, queryset, message=None):
+ self.queryset = queryset
+ self.serializer_field = None
+ self.message = message or self.message
+
+ def set_context(self, serializer_field):
+ """
+ This hook is called by the serializer instance,
+ prior to the validation call being made.
+ """
+ # Determine the underlying model field name. This may not be the
+ # same as the serializer field name if `source=<>` is set.
+ self.field_name = serializer_field.source_attrs[0]
+ # Determine the existing instance, if this is an update operation.
+ self.instance = getattr(serializer_field.parent, 'instance', None)
+
+ def filter_queryset(self, value, queryset):
+ """
+ Filter the queryset to all instances matching the given attribute.
+ """
+ filter_kwargs = {self.field_name: value}
+ return queryset.filter(**filter_kwargs)
+
+ def exclude_current_instance(self, queryset):
+ """
+ If an instance is being updated, then do not include
+ that instance itself as a uniqueness conflict.
+ """
+ if self.instance is not None:
+ return queryset.exclude(pk=self.instance.pk)
+ return queryset
+
+ def __call__(self, value):
+ queryset = self.queryset
+ queryset = self.filter_queryset(value, queryset)
+ queryset = self.exclude_current_instance(queryset)
+ if queryset.exists():
+ raise ValidationError(self.message)
+
+ def __repr__(self):
+ return '<%s(queryset=%s)>' % (
+ self.__class__.__name__,
+ smart_repr(self.queryset)
+ )
+
+
+class UniqueTogetherValidator:
+ """
+ Validator that corresponds to `unique_together = (...)` on a model class.
+
+ Should be applied to the serializer class, not to an individual field.
+ """
+ message = _('The fields {field_names} must make a unique set.')
+ missing_message = _('This field is required.')
+
+ def __init__(self, queryset, fields, message=None):
+ self.queryset = queryset
+ self.fields = fields
+ self.serializer_field = None
+ self.message = message or self.message
+
+ def set_context(self, serializer):
+ """
+ This hook is called by the serializer instance,
+ prior to the validation call being made.
+ """
+ # Determine the existing instance, if this is an update operation.
+ self.instance = getattr(serializer, 'instance', None)
+
+ def enforce_required_fields(self, attrs):
+ """
+ The `UniqueTogetherValidator` always forces an implied 'required'
+ state on the fields it applies to.
+ """
+ if self.instance is not None:
+ return
+
+ missing = dict([
+ (field_name, self.missing_message)
+ for field_name in self.fields
+ if field_name not in attrs
+ ])
+ if missing:
+ raise ValidationError(missing)
+
+ def filter_queryset(self, attrs, queryset):
+ """
+ Filter the queryset to all instances matching the given attributes.
+ """
+ # If this is an update, then any unprovided field should
+ # have it's value set based on the existing instance attribute.
+ if self.instance is not None:
+ for field_name in self.fields:
+ if field_name not in attrs:
+ attrs[field_name] = getattr(self.instance, field_name)
+
+ # Determine the filter keyword arguments and filter the queryset.
+ filter_kwargs = dict([
+ (field_name, attrs[field_name])
+ for field_name in self.fields
+ ])
+ return queryset.filter(**filter_kwargs)
+
+ def exclude_current_instance(self, attrs, queryset):
+ """
+ If an instance is being updated, then do not include
+ that instance itself as a uniqueness conflict.
+ """
+ if self.instance is not None:
+ return queryset.exclude(pk=self.instance.pk)
+ return queryset
+
+ def __call__(self, attrs):
+ self.enforce_required_fields(attrs)
+ queryset = self.queryset
+ queryset = self.filter_queryset(attrs, queryset)
+ queryset = self.exclude_current_instance(attrs, queryset)
+ if queryset.exists():
+ field_names = ', '.join(self.fields)
+ raise ValidationError(self.message.format(field_names=field_names))
+
+ def __repr__(self):
+ return '<%s(queryset=%s, fields=%s)>' % (
+ self.__class__.__name__,
+ smart_repr(self.queryset),
+ smart_repr(self.fields)
+ )
+
+
+class BaseUniqueForValidator:
+ message = None
+ missing_message = _('This field is required.')
+
+ def __init__(self, queryset, field, date_field, message=None):
+ self.queryset = queryset
+ self.field = field
+ self.date_field = date_field
+ self.message = message or self.message
+
+ def set_context(self, serializer):
+ """
+ This hook is called by the serializer instance,
+ prior to the validation call being made.
+ """
+ # Determine the underlying model field names. These may not be the
+ # same as the serializer field names if `source=<>` is set.
+ self.field_name = serializer.fields[self.field].source_attrs[0]
+ self.date_field_name = serializer.fields[self.date_field].source_attrs[0]
+ # Determine the existing instance, if this is an update operation.
+ self.instance = getattr(serializer, 'instance', None)
+
+ def enforce_required_fields(self, attrs):
+ """
+ The `UniqueFor<Range>Validator` classes always force an implied
+ 'required' state on the fields they are applied to.
+ """
+ missing = dict([
+ (field_name, self.missing_message)
+ for field_name in [self.field, self.date_field]
+ if field_name not in attrs
+ ])
+ if missing:
+ raise ValidationError(missing)
+
+ def filter_queryset(self, attrs, queryset):
+ raise NotImplementedError('`filter_queryset` must be implemented.')
+
+ def exclude_current_instance(self, attrs, queryset):
+ """
+ If an instance is being updated, then do not include
+ that instance itself as a uniqueness conflict.
+ """
+ if self.instance is not None:
+ return queryset.exclude(pk=self.instance.pk)
+ return queryset
+
+ def __call__(self, attrs):
+ self.enforce_required_fields(attrs)
+ queryset = self.queryset
+ queryset = self.filter_queryset(attrs, queryset)
+ queryset = self.exclude_current_instance(attrs, queryset)
+ if queryset.exists():
+ message = self.message.format(date_field=self.date_field)
+ raise ValidationError({self.field: message})
+
+ def __repr__(self):
+ return '<%s(queryset=%s, field=%s, date_field=%s)>' % (
+ self.__class__.__name__,
+ smart_repr(self.queryset),
+ smart_repr(self.field),
+ smart_repr(self.date_field)
+ )
+
+
+class UniqueForDateValidator(BaseUniqueForValidator):
+ message = _('This field must be unique for the "{date_field}" date.')
+
+ def filter_queryset(self, attrs, queryset):
+ value = attrs[self.field]
+ date = attrs[self.date_field]
+
+ filter_kwargs = {}
+ filter_kwargs[self.field_name] = value
+ filter_kwargs['%s__day' % self.date_field_name] = date.day
+ filter_kwargs['%s__month' % self.date_field_name] = date.month
+ filter_kwargs['%s__year' % self.date_field_name] = date.year
+ return queryset.filter(**filter_kwargs)
+
+
+class UniqueForMonthValidator(BaseUniqueForValidator):
+ message = _('This field must be unique for the "{date_field}" month.')
+
+ def filter_queryset(self, attrs, queryset):
+ value = attrs[self.field]
+ date = attrs[self.date_field]
+
+ filter_kwargs = {}
+ filter_kwargs[self.field_name] = value
+ filter_kwargs['%s__month' % self.date_field_name] = date.month
+ return queryset.filter(**filter_kwargs)
+
+
+class UniqueForYearValidator(BaseUniqueForValidator):
+ message = _('This field must be unique for the "{date_field}" year.')
+
+ def filter_queryset(self, attrs, queryset):
+ value = attrs[self.field]
+ date = attrs[self.date_field]
+
+ filter_kwargs = {}
+ filter_kwargs[self.field_name] = value
+ filter_kwargs['%s__year' % self.date_field_name] = date.year
+ return queryset.filter(**filter_kwargs)