aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/fields.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/fields.py')
-rw-r--r--rest_framework/fields.py85
1 files changed, 79 insertions, 6 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index d18551b3..250c0579 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,4 +1,6 @@
-from rest_framework.exceptions import ValidationError
+from django.core import validators
+from django.core.exceptions import ValidationError
+from django.utils.encoding import is_protected_type
from rest_framework.utils import html
import inspect
@@ -33,9 +35,14 @@ def get_attribute(instance, attrs):
"""
Similar to Python's built in `getattr(instance, attr)`,
but takes a list of nested attributes, instead of a single attribute.
+
+ Also accepts either attribute lookup on objects or dictionary lookups.
"""
for attr in attrs:
- instance = getattr(instance, attr)
+ try:
+ instance = getattr(instance, attr)
+ except AttributeError:
+ return instance[attr]
return instance
@@ -80,9 +87,11 @@ class Field(object):
'not exist in the `MESSAGES` dictionary.'
)
+ default_validators = []
+
def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None,
- label=None, style=None, error_messages=None):
+ label=None, style=None, error_messages=None, validators=[]):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -104,6 +113,7 @@ class Field(object):
self.initial = initial
self.label = label
self.style = {} if style is None else style
+ self.validators = self.default_validators + validators
def bind(self, field_name, parent, root):
"""
@@ -176,8 +186,21 @@ class Field(object):
self.fail('required')
return self.get_default()
+ self.run_validators(data)
return self.to_native(data)
+ def run_validators(self, value):
+ if value in validators.EMPTY_VALUES:
+ return
+ errors = []
+ for validator in self.validators:
+ try:
+ validator(value)
+ except ValidationError as exc:
+ errors.extend(exc.messages)
+ if errors:
+ raise ValidationError(errors)
+
def to_native(self, data):
"""
Transform the *incoming* primative data into a native value.
@@ -322,9 +345,13 @@ class IntegerField(Field):
}
def __init__(self, **kwargs):
- self.max_value = kwargs.pop('max_value')
- self.min_value = kwargs.pop('min_value')
- super(CharField, self).__init__(**kwargs)
+ max_value = kwargs.pop('max_value', None)
+ min_value = kwargs.pop('min_value', None)
+ super(IntegerField, self).__init__(**kwargs)
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
def to_native(self, data):
try:
@@ -392,3 +419,49 @@ class MethodField(Field):
attr = 'get_{field_name}'.format(field_name=self.field_name)
method = getattr(self.parent, attr)
return method(value)
+
+
+class ModelField(Field):
+ """
+ A generic field that can be used against an arbitrary model field.
+ """
+ def __init__(self, *args, **kwargs):
+ try:
+ self.model_field = kwargs.pop('model_field')
+ except KeyError:
+ raise ValueError("ModelField requires 'model_field' kwarg")
+
+ self.min_length = kwargs.pop('min_length',
+ getattr(self.model_field, 'min_length', None))
+ self.max_length = kwargs.pop('max_length',
+ getattr(self.model_field, 'max_length', None))
+ self.min_value = kwargs.pop('min_value',
+ getattr(self.model_field, 'min_value', None))
+ self.max_value = kwargs.pop('max_value',
+ getattr(self.model_field, 'max_value', None))
+
+ super(ModelField, self).__init__(*args, **kwargs)
+
+ if self.min_length is not None:
+ self.validators.append(validators.MinLengthValidator(self.min_length))
+ if self.max_length is not None:
+ self.validators.append(validators.MaxLengthValidator(self.max_length))
+ if self.min_value is not None:
+ self.validators.append(validators.MinValueValidator(self.min_value))
+ if self.max_value is not None:
+ self.validators.append(validators.MaxValueValidator(self.max_value))
+
+ def get_attribute(self, instance):
+ return get_attribute(instance, self.source_attrs[:-1])
+
+ def to_native(self, data):
+ rel = getattr(self.model_field, 'rel', None)
+ if rel is not None:
+ return rel.to._meta.get_field(rel.field_name).to_python(data)
+ return self.model_field.to_python(data)
+
+ def to_primative(self, obj):
+ value = self.model_field._get_val_from_obj(obj)
+ if is_protected_type(value):
+ return value
+ return self.model_field.value_to_string(obj)