diff options
| author | Tom Christie | 2013-04-26 13:59:06 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-04-26 13:59:06 +0100 | 
| commit | d985aec3c9034ab1aa899e914a5ac0baf314cb3c (patch) | |
| tree | 4a6f47489a64114c2918947582c5d40b83160f99 /rest_framework | |
| parent | 50c6bc5762460ebd2a79b61edd534d10cb58c7e5 (diff) | |
| parent | cac669702596cdf768971267e6355fb9223a69e8 (diff) | |
| download | django-rest-framework-d985aec3c9034ab1aa899e914a5ac0baf314cb3c.tar.bz2 | |
DecimalField
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/fields.py | 70 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 1 | ||||
| -rw-r--r-- | rest_framework/tests/fields.py | 165 | 
3 files changed, 236 insertions, 0 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 949f68d6..38fe025d 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -7,6 +7,7 @@ from __future__ import unicode_literals  import copy  import datetime +from decimal import Decimal, DecimalException  import inspect  import re  import warnings @@ -726,6 +727,75 @@ class FloatField(WritableField):              raise ValidationError(msg) +class DecimalField(WritableField): +    type_name = 'DecimalField' +    form_field_class = forms.DecimalField + +    default_error_messages = { +        'invalid': _('Enter a number.'), +        'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), +        'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), +        'max_digits': _('Ensure that there are no more than %s digits in total.'), +        'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), +        'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') +    } + +    def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): +        self.max_value, self.min_value = max_value, min_value +        self.max_digits, self.decimal_places = max_digits, decimal_places +        super(DecimalField, self).__init__(*args, **kwargs) + +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) + +    def from_native(self, value): +        """ +        Validates that the input is a decimal number. Returns a Decimal +        instance. Returns None for empty values. Ensures that there are no more +        than max_digits in the number, and no more than decimal_places digits +        after the decimal point. +        """ +        if value in validators.EMPTY_VALUES: +            return None +        value = smart_text(value).strip() +        try: +            value = Decimal(value) +        except DecimalException: +            raise ValidationError(self.error_messages['invalid']) +        return value + +    def validate(self, value): +        super(DecimalField, self).validate(value) +        if value in validators.EMPTY_VALUES: +            return +        # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, +        # since it is never equal to itself. However, NaN is the only value that +        # isn't equal to itself, so we can use this to identify NaN +        if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): +            raise ValidationError(self.error_messages['invalid']) +        sign, digittuple, exponent = value.as_tuple() +        decimals = abs(exponent) +        # digittuple doesn't include any leading zeros. +        digits = len(digittuple) +        if decimals > digits: +            # We have leading zeros up to or past the decimal point.  Count +            # everything past the decimal point as a digit.  We do not count +            # 0 before the decimal point as a digit since that would mean +            # we would not allow max_digits = decimal_places. +            digits = decimals +        whole_digits = digits - decimals + +        if self.max_digits is not None and digits > self.max_digits: +            raise ValidationError(self.error_messages['max_digits'] % self.max_digits) +        if self.decimal_places is not None and decimals > self.decimal_places: +            raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) +        if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): +            raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) +        return value + +  class FileField(WritableField):      use_files = True      type_name = 'FileField' diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index fb438b12..3d956e4d 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -560,6 +560,7 @@ class ModelSerializer(Serializer):          models.DateTimeField: DateTimeField,          models.DateField: DateField,          models.TimeField: TimeField, +        models.DecimalField: DecimalField,          models.EmailField: EmailField,          models.CharField: CharField,          models.URLField: URLField, diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 19c663d8..3cdfa0f6 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -3,12 +3,14 @@ General serializer field tests.  """  from __future__ import unicode_literals  import datetime +from decimal import Decimal  from django.db import models  from django.test import TestCase  from django.core import validators  from rest_framework import serializers +from rest_framework.serializers import Serializer  class TimestampedModel(models.Model): @@ -481,3 +483,166 @@ class TimeFieldTest(TestCase):          self.assertEqual('04 - 00 [000000]', result_1)          self.assertEqual('04 - 59 [000000]', result_2)          self.assertEqual('04 - 59 [000200]', result_3) + + +class DecimalFieldTest(TestCase): +    """ +    Tests for the DecimalField from_native() and to_native() behavior +    """ + +    def test_from_native_string(self): +        """ +        Make sure from_native() accepts string values +        """ +        f = serializers.DecimalField() +        result_1 = f.from_native('9000') +        result_2 = f.from_native('1.00000001') + +        self.assertEqual(Decimal('9000'), result_1) +        self.assertEqual(Decimal('1.00000001'), result_2) + +    def test_from_native_invalid_string(self): +        """ +        Make sure from_native() raises ValidationError on passing invalid string +        """ +        f = serializers.DecimalField() + +        try: +            f.from_native('123.45.6') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Enter a number."]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_integer(self): +        """ +        Make sure from_native() accepts integer values +        """ +        f = serializers.DecimalField() +        result = f.from_native(9000) + +        self.assertEqual(Decimal('9000'), result) + +    def test_from_native_float(self): +        """ +        Make sure from_native() accepts float values +        """ +        f = serializers.DecimalField() +        result = f.from_native(1.00000001) + +        self.assertEqual(Decimal('1.00000001'), result) + +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.DecimalField() +        result = f.from_native('') + +        self.assertEqual(result, None) + +    def test_from_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DecimalField() +        result = f.from_native(None) + +        self.assertEqual(result, None) + +    def test_to_native(self): +        """ +        Make sure to_native() returns Decimal as string. +        """ +        f = serializers.DecimalField() + +        result_1 = f.to_native(Decimal('9000')) +        result_2 = f.to_native(Decimal('1.00000001')) + +        self.assertEqual(Decimal('9000'), result_1) +        self.assertEqual(Decimal('1.00000001'), result_2) + +    def test_to_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DecimalField(required=False) +        self.assertEqual(None, f.to_native(None)) + +    def test_valid_serialization(self): +        """ +        Make sure the serializer works correctly +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_value=9010, +                                                     min_value=9000, +                                                     max_digits=6, +                                                     decimal_places=2) + +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) + +        self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) +        self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) +        self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) + +    def test_raise_max_value(self): +        """ +        Make sure max_value violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_value=100) + +        s = DecimalSerializer(data={'decimal_field': '123'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure this value is less than or equal to 100.']}) + +    def test_raise_min_value(self): +        """ +        Make sure min_value violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(min_value=100) + +        s = DecimalSerializer(data={'decimal_field': '99'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure this value is greater than or equal to 100.']}) + +    def test_raise_max_digits(self): +        """ +        Make sure max_digits violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_digits=5) + +        s = DecimalSerializer(data={'decimal_field': '123.456'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 5 digits in total.']}) + +    def test_raise_max_decimal_places(self): +        """ +        Make sure max_decimal_places violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(decimal_places=3) + +        s = DecimalSerializer(data={'decimal_field': '123.4567'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 3 decimal places.']}) + +    def test_raise_max_whole_digits(self): +        """ +        Make sure max_whole_digits violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) + +        s = DecimalSerializer(data={'decimal_field': '12345.6'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': ['Ensure that there are no more than 4 digits in total.']})
\ No newline at end of file  | 
