aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rest_framework/fields.py4
-rw-r--r--rest_framework/utils/field_mapping.py5
-rw-r--r--rest_framework/validators.py57
-rw-r--r--tests/test_validators.py35
4 files changed, 101 insertions, 0 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index f4b53279..231f693c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -150,6 +150,10 @@ class Field(object):
messages.update(error_messages or {})
self.error_messages = messages
+ for validator in validators:
+ if getattr(validator, 'requires_context', False):
+ validator.serializer_field = self
+
def bind(self, field_name, parent):
"""
Initializes the field name and parent for the field instance.
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
index c3794083..cf9d910a 100644
--- a/rest_framework/utils/field_mapping.py
+++ b/rest_framework/utils/field_mapping.py
@@ -6,6 +6,7 @@ from django.core import validators
from django.db import models
from django.utils.text import capfirst
from rest_framework.compat import clean_manytomany_helptext
+from rest_framework.validators import UniqueValidator
import inspect
@@ -156,6 +157,10 @@ def get_field_kwargs(field_name, model_field):
if validator is not validators.validate_slug
]
+ if getattr(model_field, 'unique', False):
+ validator = UniqueValidator(queryset=model_field.model._default_manager)
+ validator_kwarg.append(validator)
+
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
diff --git a/rest_framework/validators.py b/rest_framework/validators.py
new file mode 100644
index 00000000..f5fbeb3c
--- /dev/null
+++ b/rest_framework/validators.py
@@ -0,0 +1,57 @@
+from django.core.exceptions import ValidationError
+
+
+class UniqueValidator:
+ # Validators with `requires_context` will have the field instance
+ # passed to them when the field is instantiated.
+ requires_context = True
+
+ def __init__(self, queryset):
+ self.queryset = queryset
+ self.serializer_field = None
+
+ def get_queryset(self):
+ return self.queryset.all()
+
+ def __call__(self, value):
+ field = self.serializer_field
+
+ # Determine the model field name that the serializer field corresponds to.
+ field_name = field.source_attrs[0] if field.source_attrs else field.field_name
+
+ # Determine the existing instance, if this is an update operation.
+ instance = getattr(field.parent, 'instance', None)
+
+ # Ensure uniqueness.
+ filter_kwargs = {field_name: value}
+ queryset = self.get_queryset().filter(**filter_kwargs)
+ if instance:
+ queryset = queryset.exclude(pk=instance.pk)
+ if queryset.exists():
+ raise ValidationError('This field must be unique.')
+
+
+class UniqueTogetherValidator:
+ requires_context = True
+
+ def __init__(self, queryset, fields):
+ self.queryset = queryset
+ self.fields = fields
+ self.serializer_field = None
+
+ def __call__(self, value):
+ serializer = self.serializer_field
+
+ # Determine the existing instance, if this is an update operation.
+ instance = getattr(serializer, 'instance', None)
+
+ # Ensure uniqueness.
+ filter_kwargs = dict([
+ (field_name, value[field_name]) for field_name in self.fields
+ ])
+ queryset = self.get_queryset().filter(**filter_kwargs)
+ if instance:
+ queryset = queryset.exclude(pk=instance.pk)
+ if queryset.exists():
+ field_names = ' and '.join(self.fields)
+ raise ValidationError('The fields %s must make a unique set.' % field_names)
diff --git a/tests/test_validators.py b/tests/test_validators.py
new file mode 100644
index 00000000..a1366a1a
--- /dev/null
+++ b/tests/test_validators.py
@@ -0,0 +1,35 @@
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class ExampleModel(models.Model):
+ username = models.CharField(unique=True, max_length=100)
+
+
+class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ExampleModel
+
+
+class TestUniquenessValidation(TestCase):
+ def setUp(self):
+ self.instance = ExampleModel.objects.create(username='existing')
+
+ def test_is_not_unique(self):
+ data = {'username': 'existing'}
+ serializer = ExampleSerializer(data=data)
+ assert not serializer.is_valid()
+ assert serializer.errors == {'username': ['This field must be unique.']}
+
+ def test_is_unique(self):
+ data = {'username': 'other'}
+ serializer = ExampleSerializer(data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'username': 'other'}
+
+ def test_updated_instance_excluded(self):
+ data = {'username': 'existing'}
+ serializer = ExampleSerializer(self.instance, data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'username': 'existing'}