diff options
| author | José Padilla | 2014-11-28 12:14:40 -0400 |
|---|---|---|
| committer | José Padilla | 2014-11-28 12:14:40 -0400 |
| commit | 0cc990792c63caa8fa8fea62cea53b0d28157b55 (patch) | |
| tree | 7ea80a203cc8718150cd55e4403f3f4771160281 /rest_framework/validators.py | |
| parent | 1aa77830955dcdf829f65a9001b6b8900dfc8755 (diff) | |
| parent | 3a5b3772fefc3c2f2c0899947cbc07bfe6e6b5d2 (diff) | |
| download | django-rest-framework-0cc990792c63caa8fa8fea62cea53b0d28157b55.tar.bz2 | |
Merge branch 'version-3.1' into oauth_as_package
Conflicts:
requirements-test.txt
rest_framework/compat.py
tests/settings.py
tox.ini
Diffstat (limited to 'rest_framework/validators.py')
| -rw-r--r-- | rest_framework/validators.py | 254 |
1 files changed, 254 insertions, 0 deletions
diff --git a/rest_framework/validators.py b/rest_framework/validators.py new file mode 100644 index 00000000..7ca4e6a9 --- /dev/null +++ b/rest_framework/validators.py @@ -0,0 +1,254 @@ +""" +We perform uniqueness checks explicitly on the serializer class, rather +the using Django's `.full_clean()`. + +This gives us better separation of concerns, allows us to use single-step +object creation, and makes it possible to switch between using the implicit +`ModelSerializer` class and an equivelent explicit `Serializer` class. +""" +from django.utils.translation import ugettext_lazy as _ +from rest_framework.exceptions import ValidationError +from rest_framework.utils.representation import smart_repr + + +class UniqueValidator: + """ + Validator that corresponds to `unique=True` on a model field. + + Should be applied to an individual field on the serializer. + """ + message = _('This field must be unique.') + + def __init__(self, queryset, message=None): + self.queryset = queryset + self.serializer_field = None + self.message = message or self.message + + def set_context(self, serializer_field): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ + # Determine the underlying model field name. This may not be the + # same as the serializer field name if `source=<>` is set. + self.field_name = serializer_field.source_attrs[0] + # Determine the existing instance, if this is an update operation. + self.instance = getattr(serializer_field.parent, 'instance', None) + + def filter_queryset(self, value, queryset): + """ + Filter the queryset to all instances matching the given attribute. + """ + filter_kwargs = {self.field_name: value} + return queryset.filter(**filter_kwargs) + + def exclude_current_instance(self, queryset): + """ + If an instance is being updated, then do not include + that instance itself as a uniqueness conflict. + """ + if self.instance is not None: + return queryset.exclude(pk=self.instance.pk) + return queryset + + def __call__(self, value): + queryset = self.queryset + queryset = self.filter_queryset(value, queryset) + queryset = self.exclude_current_instance(queryset) + if queryset.exists(): + raise ValidationError(self.message) + + def __repr__(self): + return '<%s(queryset=%s)>' % ( + self.__class__.__name__, + smart_repr(self.queryset) + ) + + +class UniqueTogetherValidator: + """ + Validator that corresponds to `unique_together = (...)` on a model class. + + Should be applied to the serializer class, not to an individual field. + """ + message = _('The fields {field_names} must make a unique set.') + missing_message = _('This field is required.') + + def __init__(self, queryset, fields, message=None): + self.queryset = queryset + self.fields = fields + self.serializer_field = None + self.message = message or self.message + + def set_context(self, serializer): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ + # Determine the existing instance, if this is an update operation. + self.instance = getattr(serializer, 'instance', None) + + def enforce_required_fields(self, attrs): + """ + The `UniqueTogetherValidator` always forces an implied 'required' + state on the fields it applies to. + """ + if self.instance is not None: + return + + missing = dict([ + (field_name, self.missing_message) + for field_name in self.fields + if field_name not in attrs + ]) + if missing: + raise ValidationError(missing) + + def filter_queryset(self, attrs, queryset): + """ + Filter the queryset to all instances matching the given attributes. + """ + # If this is an update, then any unprovided field should + # have it's value set based on the existing instance attribute. + if self.instance is not None: + for field_name in self.fields: + if field_name not in attrs: + attrs[field_name] = getattr(self.instance, field_name) + + # Determine the filter keyword arguments and filter the queryset. + filter_kwargs = dict([ + (field_name, attrs[field_name]) + for field_name in self.fields + ]) + return queryset.filter(**filter_kwargs) + + def exclude_current_instance(self, attrs, queryset): + """ + If an instance is being updated, then do not include + that instance itself as a uniqueness conflict. + """ + if self.instance is not None: + return queryset.exclude(pk=self.instance.pk) + return queryset + + def __call__(self, attrs): + self.enforce_required_fields(attrs) + queryset = self.queryset + queryset = self.filter_queryset(attrs, queryset) + queryset = self.exclude_current_instance(attrs, queryset) + if queryset.exists(): + field_names = ', '.join(self.fields) + raise ValidationError(self.message.format(field_names=field_names)) + + def __repr__(self): + return '<%s(queryset=%s, fields=%s)>' % ( + self.__class__.__name__, + smart_repr(self.queryset), + smart_repr(self.fields) + ) + + +class BaseUniqueForValidator: + message = None + missing_message = _('This field is required.') + + def __init__(self, queryset, field, date_field, message=None): + self.queryset = queryset + self.field = field + self.date_field = date_field + self.message = message or self.message + + def set_context(self, serializer): + """ + This hook is called by the serializer instance, + prior to the validation call being made. + """ + # Determine the underlying model field names. These may not be the + # same as the serializer field names if `source=<>` is set. + self.field_name = serializer.fields[self.field].source_attrs[0] + self.date_field_name = serializer.fields[self.date_field].source_attrs[0] + # Determine the existing instance, if this is an update operation. + self.instance = getattr(serializer, 'instance', None) + + def enforce_required_fields(self, attrs): + """ + The `UniqueFor<Range>Validator` classes always force an implied + 'required' state on the fields they are applied to. + """ + missing = dict([ + (field_name, self.missing_message) + for field_name in [self.field, self.date_field] + if field_name not in attrs + ]) + if missing: + raise ValidationError(missing) + + def filter_queryset(self, attrs, queryset): + raise NotImplementedError('`filter_queryset` must be implemented.') + + def exclude_current_instance(self, attrs, queryset): + """ + If an instance is being updated, then do not include + that instance itself as a uniqueness conflict. + """ + if self.instance is not None: + return queryset.exclude(pk=self.instance.pk) + return queryset + + def __call__(self, attrs): + self.enforce_required_fields(attrs) + queryset = self.queryset + queryset = self.filter_queryset(attrs, queryset) + queryset = self.exclude_current_instance(attrs, queryset) + if queryset.exists(): + message = self.message.format(date_field=self.date_field) + raise ValidationError({self.field: message}) + + def __repr__(self): + return '<%s(queryset=%s, field=%s, date_field=%s)>' % ( + self.__class__.__name__, + smart_repr(self.queryset), + smart_repr(self.field), + smart_repr(self.date_field) + ) + + +class UniqueForDateValidator(BaseUniqueForValidator): + message = _('This field must be unique for the "{date_field}" date.') + + def filter_queryset(self, attrs, queryset): + value = attrs[self.field] + date = attrs[self.date_field] + + filter_kwargs = {} + filter_kwargs[self.field_name] = value + filter_kwargs['%s__day' % self.date_field_name] = date.day + filter_kwargs['%s__month' % self.date_field_name] = date.month + filter_kwargs['%s__year' % self.date_field_name] = date.year + return queryset.filter(**filter_kwargs) + + +class UniqueForMonthValidator(BaseUniqueForValidator): + message = _('This field must be unique for the "{date_field}" month.') + + def filter_queryset(self, attrs, queryset): + value = attrs[self.field] + date = attrs[self.date_field] + + filter_kwargs = {} + filter_kwargs[self.field_name] = value + filter_kwargs['%s__month' % self.date_field_name] = date.month + return queryset.filter(**filter_kwargs) + + +class UniqueForYearValidator(BaseUniqueForValidator): + message = _('This field must be unique for the "{date_field}" year.') + + def filter_queryset(self, attrs, queryset): + value = attrs[self.field] + date = attrs[self.date_field] + + filter_kwargs = {} + filter_kwargs[self.field_name] = value + filter_kwargs['%s__year' % self.date_field_name] = date.year + return queryset.filter(**filter_kwargs) |
