diff options
| -rw-r--r-- | rest_framework/fields.py | 10 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 26 | ||||
| -rw-r--r-- | rest_framework/validators.py | 121 | ||||
| -rw-r--r-- | tests/test_validators.py | 68 | 
4 files changed, 191 insertions, 34 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 2da4aa8b..e939b2f2 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -268,11 +268,17 @@ class Field(object):          """          errors = []          for validator in self.validators: -            if getattr(validator, 'requires_context', False): -                validator.serializer_field = self +            if hasattr(validator, 'set_context'): +                validator.set_context(self) +              try:                  validator(value)              except ValidationError as exc: +                # If the validation error contains a mapping of fields to +                # errors then simply raise it immediately rather than +                # attempting to accumulate a list of errors. +                if isinstance(exc.detail, dict): +                    raise                  errors.extend(exc.detail)              except DjangoValidationError as exc:                  errors.extend(exc.messages) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 59f38a73..5770fcf6 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -23,7 +23,9 @@ from rest_framework.utils.field_mapping import (      get_relation_kwargs, get_nested_relation_kwargs,      ClassLookupDict  ) -from rest_framework.validators import UniqueTogetherValidator +from rest_framework.validators import ( +    UniqueForDateValidator, UniqueTogetherValidator +)  import copy  import inspect  import warnings @@ -578,15 +580,9 @@ class ModelSerializer(Serializer):          validators = []          model_class = self.Meta.model -        for unique_together in model_class._meta.unique_together: -            if field_names.issuperset(set(unique_together)): -                validator = UniqueTogetherValidator( -                    queryset=model_class._default_manager, -                    fields=unique_together -                ) -                validators.append(validator) - -        for parent_class in model_class._meta.parents.keys(): +        # Note that we make sure to check `unique_together` both on the +        # base model class, but also on any parent classes. +        for parent_class in [model_class] + list(model_class._meta.parents.keys()):              for unique_together in parent_class._meta.unique_together:                  if field_names.issuperset(set(unique_together)):                      validator = UniqueTogetherValidator( @@ -595,6 +591,16 @@ class ModelSerializer(Serializer):                      )                      validators.append(validator) +        info = model_meta.get_field_info(model_class) +        for field_name, field in info.fields_and_pk.items(): +            if field.unique_for_date and field_name in field_names: +                validator = UniqueForDateValidator( +                    queryset=model_class._default_manager, +                    field=field_name, +                    date_field=field.unique_for_date +                ) +                validators.append(validator) +          return validators      def _get_base_fields(self): diff --git a/rest_framework/validators.py b/rest_framework/validators.py index f76faaa4..e302a0e4 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -6,38 +6,36 @@ This gives us better separation of concerns, allows us to use single-step  object creation, and makes it possible to switch between using the implicit  `ModelSerializer` class and an equivelent explicit `Serializer` class.  """ -from django.core.exceptions import ValidationError  from django.utils.translation import ugettext_lazy as _ +from rest_framework.exceptions import ValidationError  from rest_framework.utils.representation import smart_repr  class UniqueValidator:      """      Validator that corresponds to `unique=True` on a model field. + +    Should be applied to an individual field on the serializer.      """ -    # Validators with `requires_context` will have the field instance -    # passed to them when the field is instantiated. -    requires_context = True      message = _('This field must be unique.')      def __init__(self, queryset):          self.queryset = queryset          self.serializer_field = None -    def __call__(self, value): -        field = self.serializer_field - -        # Determine the model field name that the serializer field corresponds to. -        field_name = field.source_attrs[0] if field.source_attrs else field.field_name - +    def set_context(self, serializer_field): +        # Determine the underlying model field name. This may not be the +        # same as the serializer field name if `source=<>` is set. +        self.field_name = serializer_field.source_attrs[0]          # Determine the existing instance, if this is an update operation. -        instance = getattr(field.parent, 'instance', None) +        self.instance = getattr(serializer_field.parent, 'instance', None) +    def __call__(self, value):          # Ensure uniqueness. -        filter_kwargs = {field_name: value} +        filter_kwargs = {self.field_name: value}          queryset = self.queryset.filter(**filter_kwargs) -        if instance: -            queryset = queryset.exclude(pk=instance.pk) +        if self.instance is not None: +            queryset = queryset.exclude(pk=self.instance.pk)          if queryset.exists():              raise ValidationError(self.message) @@ -51,8 +49,9 @@ class UniqueValidator:  class UniqueTogetherValidator:      """      Validator that corresponds to `unique_together = (...)` on a model class. + +    Should be applied to the serializer class, not to an individual field.      """ -    requires_context = True      message = _('The fields {field_names} must make a unique set.')      def __init__(self, queryset, fields): @@ -60,19 +59,18 @@ class UniqueTogetherValidator:          self.fields = fields          self.serializer_field = None -    def __call__(self, value): -        serializer = self.serializer_field - +    def set_context(self, serializer):          # Determine the existing instance, if this is an update operation. -        instance = getattr(serializer, 'instance', None) +        self.instance = getattr(serializer, 'instance', None) +    def __call__(self, attrs):          # Ensure uniqueness.          filter_kwargs = dict([ -            (field_name, value[field_name]) for field_name in self.fields +            (field_name, attrs[field_name]) for field_name in self.fields          ])          queryset = self.queryset.filter(**filter_kwargs) -        if instance: -            queryset = queryset.exclude(pk=instance.pk) +        if self.instance is not None: +            queryset = queryset.exclude(pk=self.instance.pk)          if queryset.exists():              field_names = ', '.join(self.fields)              raise ValidationError(self.message.format(field_names=field_names)) @@ -83,3 +81,82 @@ class UniqueTogetherValidator:              smart_repr(self.queryset),              smart_repr(self.fields)          ) + + +class BaseUniqueForValidator: +    message = None + +    def __init__(self, queryset, field, date_field): +        self.queryset = queryset +        self.field = field +        self.date_field = date_field + +    def set_context(self, serializer): +        # Determine the underlying model field names. These may not be the +        # same as the serializer field names if `source=<>` is set. +        self.field_name = serializer.fields[self.field].source_attrs[0] +        self.date_field_name = serializer.fields[self.date_field].source_attrs[0] +        # Determine the existing instance, if this is an update operation. +        self.instance = getattr(serializer, 'instance', None) + +    def get_filter_kwargs(self, attrs): +        raise NotImplementedError('`get_filter_kwargs` must be implemented.') + +    def __call__(self, attrs): +        filter_kwargs = self.get_filter_kwargs(attrs) + +        queryset = self.queryset.filter(**filter_kwargs) +        if self.instance is not None: +            queryset = queryset.exclude(pk=self.instance.pk) +        if queryset.exists(): +            message = self.message.format(date_field=self.date_field) +            raise ValidationError({self.field: message}) + +    def __repr__(self): +        return '<%s(queryset=%s, field=%s, date_field=%s)>' % ( +            self.__class__.__name__, +            smart_repr(self.queryset), +            smart_repr(self.field), +            smart_repr(self.date_field) +        ) + + +class UniqueForDateValidator(BaseUniqueForValidator): +    message = _('This field must be unique for the "{date_field}" date.') + +    def get_filter_kwargs(self, attrs): +        value = attrs[self.field] +        date = attrs[self.date_field] + +        filter_kwargs = {} +        filter_kwargs[self.field_name] = value +        filter_kwargs['%s__day' % self.date_field_name] = date.day +        filter_kwargs['%s__month' % self.date_field_name] = date.month +        filter_kwargs['%s__year' % self.date_field_name] = date.year +        return filter_kwargs + + +class UniqueForMonthValidator(BaseUniqueForValidator): +    message = _('This field must be unique for the "{date_field}" month.') + +    def get_filter_kwargs(self, attrs): +        value = attrs[self.field] +        date = attrs[self.date_field] + +        filter_kwargs = {} +        filter_kwargs[self.field_name] = value +        filter_kwargs['%s__month' % self.date_field_name] = date.month +        return filter_kwargs + + +class UniqueForYearValidator(BaseUniqueForValidator): +    message = _('This field must be unique for the "{date_field}" year.') + +    def get_filter_kwargs(self, attrs): +        value = attrs[self.field] +        date = attrs[self.date_field] + +        filter_kwargs = {} +        filter_kwargs[self.field_name] = value +        filter_kwargs['%s__year' % self.date_field_name] = date.year +        return filter_kwargs diff --git a/tests/test_validators.py b/tests/test_validators.py index 1d081411..5adb7678 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,6 +1,7 @@  from django.db import models  from django.test import TestCase  from rest_framework import serializers +import datetime  def dedent(blocktext): @@ -147,3 +148,70 @@ class TestUniquenessTogetherValidation(TestCase):                  race_name = CharField(max_length=100)          """)          assert repr(serializer) == expected + + +# Tests for `UniqueForDateValidator` +# ---------------------------------- + +class UniqueForDateModel(models.Model): +    slug = models.CharField(max_length=100, unique_for_date='published') +    published = models.DateField() + + +class UniqueForDateSerializer(serializers.ModelSerializer): +    class Meta: +        model = UniqueForDateModel + + +class TestUniquenessForDateValidation(TestCase): +    def setUp(self): +        self.instance = UniqueForDateModel.objects.create( +            slug='existing', +            published='2000-01-01' +        ) + +    def test_repr(self): +        serializer = UniqueForDateSerializer() +        expected = dedent(""" +            UniqueForDateSerializer(validators=[<UniqueForDateValidator(queryset=UniqueForDateModel.objects.all(), field='slug', date_field='published')>]): +                id = IntegerField(label='ID', read_only=True) +                slug = CharField(max_length=100) +                published = DateField() +        """) +        assert repr(serializer) == expected + +    def test_is_not_unique_for_date(self): +        """ +        Failing unique for date validation should result in field error. +        """ +        data = {'slug': 'existing', 'published': '2000-01-01'} +        serializer = UniqueForDateSerializer(data=data) +        assert not serializer.is_valid() +        assert serializer.errors == { +            'slug': ['This field must be unique for the "published" date.'] +        } + +    def test_is_unique_for_date(self): +        """ +        Passing unique for date validation. +        """ +        data = {'slug': 'existing', 'published': '2000-01-02'} +        serializer = UniqueForDateSerializer(data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'slug': 'existing', +            'published': datetime.date(2000, 1, 2) +        } + +    def test_updated_instance_excluded_from_unique_for_date(self): +        """ +        When performing an update, the existing instance does not count +        as a match against unique_for_date. +        """ +        data = {'slug': 'existing', 'published': '2000-01-01'} +        serializer = UniqueForDateSerializer(instance=self.instance, data=data) +        assert serializer.is_valid() +        assert serializer.validated_data == { +            'slug': 'existing', +            'published': datetime.date(2000, 1, 1) +        } | 
