diff options
| author | Tom Christie | 2014-10-22 13:30:28 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-10-22 13:30:28 +0100 | 
| commit | ae53fdff9c6bb3e81a1ec005134462f0d629688f (patch) | |
| tree | 3dc5590e7961491605b2009322f631af7ffbc01b /rest_framework/validators.py | |
| parent | c5d1be8eac6cdb5cce000ec7c55e1847bfcf2359 (diff) | |
| download | django-rest-framework-ae53fdff9c6bb3e81a1ec005134462f0d629688f.tar.bz2 | |
First pass at unique_for_date, unique_for_month, unique_for_year
Diffstat (limited to 'rest_framework/validators.py')
| -rw-r--r-- | rest_framework/validators.py | 121 | 
1 files changed, 99 insertions, 22 deletions
| diff --git a/rest_framework/validators.py b/rest_framework/validators.py index f76faaa4..e302a0e4 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -6,38 +6,36 @@ This gives us better separation of concerns, allows us to use single-step  object creation, and makes it possible to switch between using the implicit  `ModelSerializer` class and an equivelent explicit `Serializer` class.  """ -from django.core.exceptions import ValidationError  from django.utils.translation import ugettext_lazy as _ +from rest_framework.exceptions import ValidationError  from rest_framework.utils.representation import smart_repr  class UniqueValidator:      """      Validator that corresponds to `unique=True` on a model field. + +    Should be applied to an individual field on the serializer.      """ -    # Validators with `requires_context` will have the field instance -    # passed to them when the field is instantiated. -    requires_context = True      message = _('This field must be unique.')      def __init__(self, queryset):          self.queryset = queryset          self.serializer_field = None -    def __call__(self, value): -        field = self.serializer_field - -        # Determine the model field name that the serializer field corresponds to. -        field_name = field.source_attrs[0] if field.source_attrs else field.field_name - +    def set_context(self, serializer_field): +        # Determine the underlying model field name. This may not be the +        # same as the serializer field name if `source=<>` is set. +        self.field_name = serializer_field.source_attrs[0]          # Determine the existing instance, if this is an update operation. -        instance = getattr(field.parent, 'instance', None) +        self.instance = getattr(serializer_field.parent, 'instance', None) +    def __call__(self, value):          # Ensure uniqueness. -        filter_kwargs = {field_name: value} +        filter_kwargs = {self.field_name: value}          queryset = self.queryset.filter(**filter_kwargs) -        if instance: -            queryset = queryset.exclude(pk=instance.pk) +        if self.instance is not None: +            queryset = queryset.exclude(pk=self.instance.pk)          if queryset.exists():              raise ValidationError(self.message) @@ -51,8 +49,9 @@ class UniqueValidator:  class UniqueTogetherValidator:      """      Validator that corresponds to `unique_together = (...)` on a model class. + +    Should be applied to the serializer class, not to an individual field.      """ -    requires_context = True      message = _('The fields {field_names} must make a unique set.')      def __init__(self, queryset, fields): @@ -60,19 +59,18 @@ class UniqueTogetherValidator:          self.fields = fields          self.serializer_field = None -    def __call__(self, value): -        serializer = self.serializer_field - +    def set_context(self, serializer):          # Determine the existing instance, if this is an update operation. -        instance = getattr(serializer, 'instance', None) +        self.instance = getattr(serializer, 'instance', None) +    def __call__(self, attrs):          # Ensure uniqueness.          filter_kwargs = dict([ -            (field_name, value[field_name]) for field_name in self.fields +            (field_name, attrs[field_name]) for field_name in self.fields          ])          queryset = self.queryset.filter(**filter_kwargs) -        if instance: -            queryset = queryset.exclude(pk=instance.pk) +        if self.instance is not None: +            queryset = queryset.exclude(pk=self.instance.pk)          if queryset.exists():              field_names = ', '.join(self.fields)              raise ValidationError(self.message.format(field_names=field_names)) @@ -83,3 +81,82 @@ class UniqueTogetherValidator:              smart_repr(self.queryset),              smart_repr(self.fields)          ) + + +class BaseUniqueForValidator: +    message = None + +    def __init__(self, queryset, field, date_field): +        self.queryset = queryset +        self.field = field +        self.date_field = date_field + +    def set_context(self, serializer): +        # Determine the underlying model field names. These may not be the +        # same as the serializer field names if `source=<>` is set. +        self.field_name = serializer.fields[self.field].source_attrs[0] +        self.date_field_name = serializer.fields[self.date_field].source_attrs[0] +        # Determine the existing instance, if this is an update operation. +        self.instance = getattr(serializer, 'instance', None) + +    def get_filter_kwargs(self, attrs): +        raise NotImplementedError('`get_filter_kwargs` must be implemented.') + +    def __call__(self, attrs): +        filter_kwargs = self.get_filter_kwargs(attrs) + +        queryset = self.queryset.filter(**filter_kwargs) +        if self.instance is not None: +            queryset = queryset.exclude(pk=self.instance.pk) +        if queryset.exists(): +            message = self.message.format(date_field=self.date_field) +            raise ValidationError({self.field: message}) + +    def __repr__(self): +        return '<%s(queryset=%s, field=%s, date_field=%s)>' % ( +            self.__class__.__name__, +            smart_repr(self.queryset), +            smart_repr(self.field), +            smart_repr(self.date_field) +        ) + + +class UniqueForDateValidator(BaseUniqueForValidator): +    message = _('This field must be unique for the "{date_field}" date.') + +    def get_filter_kwargs(self, attrs): +        value = attrs[self.field] +        date = attrs[self.date_field] + +        filter_kwargs = {} +        filter_kwargs[self.field_name] = value +        filter_kwargs['%s__day' % self.date_field_name] = date.day +        filter_kwargs['%s__month' % self.date_field_name] = date.month +        filter_kwargs['%s__year' % self.date_field_name] = date.year +        return filter_kwargs + + +class UniqueForMonthValidator(BaseUniqueForValidator): +    message = _('This field must be unique for the "{date_field}" month.') + +    def get_filter_kwargs(self, attrs): +        value = attrs[self.field] +        date = attrs[self.date_field] + +        filter_kwargs = {} +        filter_kwargs[self.field_name] = value +        filter_kwargs['%s__month' % self.date_field_name] = date.month +        return filter_kwargs + + +class UniqueForYearValidator(BaseUniqueForValidator): +    message = _('This field must be unique for the "{date_field}" year.') + +    def get_filter_kwargs(self, attrs): +        value = attrs[self.field] +        date = attrs[self.date_field] + +        filter_kwargs = {} +        filter_kwargs[self.field_name] = value +        filter_kwargs['%s__year' % self.date_field_name] = date.year +        return filter_kwargs | 
