aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--rest_framework/fields.py13
-rw-r--r--tests/test_fields.py31
2 files changed, 30 insertions, 14 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 5105dfcb..5fb99a42 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -209,8 +209,10 @@ class Field(object):
"""
Validate a simple representation and return the internal value.
- The provided data may be `empty` if no representation was included.
- May return `empty` if the field should not be included in the
+ The provided data may be `empty` if no representation was included
+ in the input.
+
+ May raise `SkipField` if the field should not be included in the
validated data.
"""
if data is empty:
@@ -223,6 +225,10 @@ class Field(object):
return value
def run_validators(self, value):
+ """
+ Test the given value against all the validators on the field,
+ and either raise a `ValidationError` or simply return.
+ """
if value in (None, '', [], (), {}):
return
@@ -753,8 +759,9 @@ class MultipleChoiceField(ChoiceField):
}
def to_internal_value(self, data):
- if not hasattr(data, '__iter__'):
+ if isinstance(data, type('')) or not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__)
+
return set([
super(MultipleChoiceField, self).to_internal_value(item)
for item in data
diff --git a/tests/test_fields.py b/tests/test_fields.py
index ae7f1919..e03ece54 100644
--- a/tests/test_fields.py
+++ b/tests/test_fields.py
@@ -5,22 +5,31 @@ import datetime
import pytest
+def get_items(mapping_or_list_of_two_tuples):
+ # Tests accept either lists of two tuples, or dictionaries.
+ if isinstance(mapping_or_list_of_two_tuples, dict):
+ # {value: expected}
+ return mapping_or_list_of_two_tuples.items()
+ # [(value, expected), ...]
+ return mapping_or_list_of_two_tuples
+
+
class ValidAndInvalidValues:
"""
- Base class for testing valid and invalid field values.
+ Base class for testing valid and invalid input values.
"""
def test_valid_values(self):
"""
Ensure that valid values return the expected validated data.
"""
- for input_value, expected_output in self.valid_mappings.items():
+ for input_value, expected_output in get_items(self.valid_mappings):
assert self.field.run_validation(input_value) == expected_output
def test_invalid_values(self):
"""
Ensure that invalid values raise the expected validation error.
"""
- for input_value, expected_failure in self.invalid_mappings.items():
+ for input_value, expected_failure in get_items(self.invalid_mappings):
with pytest.raises(fields.ValidationError) as exc_info:
self.field.run_validation(input_value)
assert exc_info.value.messages == expected_failure
@@ -189,14 +198,14 @@ class TestDecimalField(ValidAndInvalidValues):
12.3: Decimal('12.3'),
0.1: Decimal('0.1'),
}
- invalid_mappings = {
- 'abc': ["A valid number is required."],
- Decimal('Nan'): ["A valid number is required."],
- Decimal('Inf'): ["A valid number is required."],
- '12.345': ["Ensure that there are no more than 3 digits in total."],
- '0.01': ["Ensure that there are no more than 1 decimal places."],
- 123: ["Ensure that there are no more than 2 digits before the decimal point."]
- }
+ invalid_mappings = (
+ ('abc', ["A valid number is required."]),
+ (Decimal('Nan'), ["A valid number is required."]),
+ (Decimal('Inf'), ["A valid number is required."]),
+ ('12.345', ["Ensure that there are no more than 3 digits in total."]),
+ ('0.01', ["Ensure that there are no more than 1 decimal places."]),
+ (123, ["Ensure that there are no more than 2 digits before the decimal point."])
+ )
field = fields.DecimalField(max_digits=3, decimal_places=1)