diff options
Diffstat (limited to 'rest_framework/fields.py')
| -rw-r--r-- | rest_framework/fields.py | 85 |
1 files changed, 79 insertions, 6 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d18551b3..250c0579 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,4 +1,6 @@ -from rest_framework.exceptions import ValidationError +from django.core import validators +from django.core.exceptions import ValidationError +from django.utils.encoding import is_protected_type from rest_framework.utils import html import inspect @@ -33,9 +35,14 @@ def get_attribute(instance, attrs): """ Similar to Python's built in `getattr(instance, attr)`, but takes a list of nested attributes, instead of a single attribute. + + Also accepts either attribute lookup on objects or dictionary lookups. """ for attr in attrs: - instance = getattr(instance, attr) + try: + instance = getattr(instance, attr) + except AttributeError: + return instance[attr] return instance @@ -80,9 +87,11 @@ class Field(object): 'not exist in the `MESSAGES` dictionary.' ) + default_validators = [] + def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, - label=None, style=None, error_messages=None): + label=None, style=None, error_messages=None, validators=[]): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -104,6 +113,7 @@ class Field(object): self.initial = initial self.label = label self.style = {} if style is None else style + self.validators = self.default_validators + validators def bind(self, field_name, parent, root): """ @@ -176,8 +186,21 @@ class Field(object): self.fail('required') return self.get_default() + self.run_validators(data) return self.to_native(data) + def run_validators(self, value): + if value in validators.EMPTY_VALUES: + return + errors = [] + for validator in self.validators: + try: + validator(value) + except ValidationError as exc: + errors.extend(exc.messages) + if errors: + raise ValidationError(errors) + def to_native(self, data): """ Transform the *incoming* primative data into a native value. @@ -322,9 +345,13 @@ class IntegerField(Field): } def __init__(self, **kwargs): - self.max_value = kwargs.pop('max_value') - self.min_value = kwargs.pop('min_value') - super(CharField, self).__init__(**kwargs) + max_value = kwargs.pop('max_value', None) + min_value = kwargs.pop('min_value', None) + super(IntegerField, self).__init__(**kwargs) + if max_value is not None: + self.validators.append(validators.MaxValueValidator(max_value)) + if min_value is not None: + self.validators.append(validators.MinValueValidator(min_value)) def to_native(self, data): try: @@ -392,3 +419,49 @@ class MethodField(Field): attr = 'get_{field_name}'.format(field_name=self.field_name) method = getattr(self.parent, attr) return method(value) + + +class ModelField(Field): + """ + A generic field that can be used against an arbitrary model field. + """ + def __init__(self, *args, **kwargs): + try: + self.model_field = kwargs.pop('model_field') + except KeyError: + raise ValueError("ModelField requires 'model_field' kwarg") + + self.min_length = kwargs.pop('min_length', + getattr(self.model_field, 'min_length', None)) + self.max_length = kwargs.pop('max_length', + getattr(self.model_field, 'max_length', None)) + self.min_value = kwargs.pop('min_value', + getattr(self.model_field, 'min_value', None)) + self.max_value = kwargs.pop('max_value', + getattr(self.model_field, 'max_value', None)) + + super(ModelField, self).__init__(*args, **kwargs) + + if self.min_length is not None: + self.validators.append(validators.MinLengthValidator(self.min_length)) + if self.max_length is not None: + self.validators.append(validators.MaxLengthValidator(self.max_length)) + if self.min_value is not None: + self.validators.append(validators.MinValueValidator(self.min_value)) + if self.max_value is not None: + self.validators.append(validators.MaxValueValidator(self.max_value)) + + def get_attribute(self, instance): + return get_attribute(instance, self.source_attrs[:-1]) + + def to_native(self, data): + rel = getattr(self.model_field, 'rel', None) + if rel is not None: + return rel.to._meta.get_field(rel.field_name).to_python(data) + return self.model_field.to_python(data) + + def to_primative(self, obj): + value = self.model_field._get_val_from_obj(obj) + if is_protected_type(value): + return value + return self.model_field.value_to_string(obj) |
