diff options
| author | Tom Christie | 2014-09-22 14:54:33 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-09-22 14:54:33 +0100 | 
| commit | 249253a144ba4381581809fb3f27959c7bd6e577 (patch) | |
| tree | 2af20e6b13a344a1af960b2453e1678dcf63b40f | |
| parent | c54f394904c3f93211b8aa073de4e9e50110f831 (diff) | |
| download | django-rest-framework-249253a144ba4381581809fb3f27959c7bd6e577.tar.bz2 | |
Fix compat issues
| -rw-r--r-- | rest_framework/fields.py | 13 | ||||
| -rw-r--r-- | tests/test_fields.py | 31 | 
2 files changed, 30 insertions, 14 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 5105dfcb..5fb99a42 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -209,8 +209,10 @@ class Field(object):          """          Validate a simple representation and return the internal value. -        The provided data may be `empty` if no representation was included. -        May return `empty` if the field should not be included in the +        The provided data may be `empty` if no representation was included +        in the input. + +        May raise `SkipField` if the field should not be included in the          validated data.          """          if data is empty: @@ -223,6 +225,10 @@ class Field(object):          return value      def run_validators(self, value): +        """ +        Test the given value against all the validators on the field, +        and either raise a `ValidationError` or simply return. +        """          if value in (None, '', [], (), {}):              return @@ -753,8 +759,9 @@ class MultipleChoiceField(ChoiceField):      }      def to_internal_value(self, data): -        if not hasattr(data, '__iter__'): +        if isinstance(data, type('')) or not hasattr(data, '__iter__'):              self.fail('not_a_list', input_type=type(data).__name__) +          return set([              super(MultipleChoiceField, self).to_internal_value(item)              for item in data diff --git a/tests/test_fields.py b/tests/test_fields.py index ae7f1919..e03ece54 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -5,22 +5,31 @@ import datetime  import pytest +def get_items(mapping_or_list_of_two_tuples): +    # Tests accept either lists of two tuples, or dictionaries. +    if isinstance(mapping_or_list_of_two_tuples, dict): +        # {value: expected} +        return mapping_or_list_of_two_tuples.items() +    # [(value, expected), ...] +    return mapping_or_list_of_two_tuples + +  class ValidAndInvalidValues:      """ -    Base class for testing valid and invalid field values. +    Base class for testing valid and invalid input values.      """      def test_valid_values(self):          """          Ensure that valid values return the expected validated data.          """ -        for input_value, expected_output in self.valid_mappings.items(): +        for input_value, expected_output in get_items(self.valid_mappings):              assert self.field.run_validation(input_value) == expected_output      def test_invalid_values(self):          """          Ensure that invalid values raise the expected validation error.          """ -        for input_value, expected_failure in self.invalid_mappings.items(): +        for input_value, expected_failure in get_items(self.invalid_mappings):              with pytest.raises(fields.ValidationError) as exc_info:                  self.field.run_validation(input_value)              assert exc_info.value.messages == expected_failure @@ -189,14 +198,14 @@ class TestDecimalField(ValidAndInvalidValues):          12.3: Decimal('12.3'),          0.1: Decimal('0.1'),      } -    invalid_mappings = { -        'abc': ["A valid number is required."], -        Decimal('Nan'): ["A valid number is required."], -        Decimal('Inf'): ["A valid number is required."], -        '12.345': ["Ensure that there are no more than 3 digits in total."], -        '0.01': ["Ensure that there are no more than 1 decimal places."], -        123: ["Ensure that there are no more than 2 digits before the decimal point."] -    } +    invalid_mappings = ( +        ('abc', ["A valid number is required."]), +        (Decimal('Nan'), ["A valid number is required."]), +        (Decimal('Inf'), ["A valid number is required."]), +        ('12.345', ["Ensure that there are no more than 3 digits in total."]), +        ('0.01', ["Ensure that there are no more than 1 decimal places."]), +        (123, ["Ensure that there are no more than 2 digits before the decimal point."]) +    )      field = fields.DecimalField(max_digits=3, decimal_places=1) | 
