diff options
| author | Stephan Groß | 2013-04-15 12:40:18 +0200 | 
|---|---|---|
| committer | Stephan Groß | 2013-04-15 12:40:18 +0200 | 
| commit | 1b5382c146c1de902cc83b11a66a5f9909149691 (patch) | |
| tree | f3bd9a456294464a2d6c82606a46b71138d80af8 | |
| parent | 5b56639e7a26ba31ffe472b69408c427346df85b (diff) | |
| download | django-rest-framework-1b5382c146c1de902cc83b11a66a5f9909149691.tar.bz2 | |
Add DecimalField support
| -rw-r--r-- | docs/api-guide/fields.md | 6 | ||||
| -rw-r--r-- | docs/topics/release-notes.md | 1 | ||||
| -rw-r--r-- | rest_framework/fields.py | 75 | ||||
| -rw-r--r-- | rest_framework/tests/fields.py | 165 | 
4 files changed, 247 insertions, 0 deletions
diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index 42f89f46..e117c370 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -248,6 +248,12 @@ A floating point representation.  Corresponds to `django.db.models.fields.FloatField`. +## DecimalField + +A decimal representation. + +Corresponds to `django.db.models.fields.DecimalField`. +  ## FileField  A file representation. Performs Django's standard FileField validation.  diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 609b4504..66022959 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -42,6 +42,7 @@ You can determine your currently installed version using `pip freeze`:  ### Master +* DecimalField support.  * OAuth2 authentication no longer requires unneccessary URL parameters in addition to the token.  * URL hyperlinking in browseable API now handles more cases correctly.  * Long HTTP headers in browsable API are broken in multiple lines when possible. diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f3496b53..a1b9f546 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals  import copy  import datetime +from decimal import Decimal, DecimalException  import inspect  import re  import warnings @@ -721,6 +722,80 @@ class FloatField(WritableField):              raise ValidationError(msg) +class DecimalField(WritableField): +    type_name = 'DecimalField' +    form_field_class = forms.DecimalField + +    default_error_messages = { +        'invalid': _('Enter a number.'), +        'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'), +        'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'), +        'max_digits': _('Ensure that there are no more than %s digits in total.'), +        'max_decimal_places': _('Ensure that there are no more than %s decimal places.'), +        'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.') +    } + +    def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs): +        self.max_value, self.min_value = max_value, min_value +        self.max_digits, self.decimal_places = max_digits, decimal_places +        super(DecimalField, self).__init__(self, *args, **kwargs) + +        if max_value is not None: +            self.validators.append(validators.MaxValueValidator(max_value)) +        if min_value is not None: +            self.validators.append(validators.MinValueValidator(min_value)) + +    def from_native(self, value): +        """ +        Validates that the input is a decimal number. Returns a Decimal +        instance. Returns None for empty values. Ensures that there are no more +        than max_digits in the number, and no more than decimal_places digits +        after the decimal point. +        """ +        if value in validators.EMPTY_VALUES: +            return None +        value = smart_text(value).strip() +        try: +            value = Decimal(value) +        except DecimalException: +            raise ValidationError(self.error_messages['invalid']) +        return value + +    def to_native(self, value): +        if value is not None: +            return str(value) +        return value + +    def validate(self, value): +        super(DecimalField, self).validate(value) +        if value in validators.EMPTY_VALUES: +            return +        # Check for NaN, Inf and -Inf values. We can't compare directly for NaN, +        # since it is never equal to itself. However, NaN is the only value that +        # isn't equal to itself, so we can use this to identify NaN +        if value != value or value == Decimal("Inf") or value == Decimal("-Inf"): +            raise ValidationError(self.error_messages['invalid']) +        sign, digittuple, exponent = value.as_tuple() +        decimals = abs(exponent) +        # digittuple doesn't include any leading zeros. +        digits = len(digittuple) +        if decimals > digits: +            # We have leading zeros up to or past the decimal point.  Count +            # everything past the decimal point as a digit.  We do not count +            # 0 before the decimal point as a digit since that would mean +            # we would not allow max_digits = decimal_places. +            digits = decimals +        whole_digits = digits - decimals + +        if self.max_digits is not None and digits > self.max_digits: +            raise ValidationError(self.error_messages['max_digits'] % self.max_digits) +        if self.decimal_places is not None and decimals > self.decimal_places: +            raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places) +        if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places): +            raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places)) +        return value + +  class FileField(WritableField):      use_files = True      type_name = 'FileField' diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index 19c663d8..f833aa32 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -3,12 +3,14 @@ General serializer field tests.  """  from __future__ import unicode_literals  import datetime +from decimal import Decimal  from django.db import models  from django.test import TestCase  from django.core import validators  from rest_framework import serializers +from rest_framework.serializers import Serializer  class TimestampedModel(models.Model): @@ -481,3 +483,166 @@ class TimeFieldTest(TestCase):          self.assertEqual('04 - 00 [000000]', result_1)          self.assertEqual('04 - 59 [000000]', result_2)          self.assertEqual('04 - 59 [000200]', result_3) + + +class DecimalFieldTest(TestCase): +    """ +    Tests for the DecimalField from_native() and to_native() behavior +    """ + +    def test_from_native_string(self): +        """ +        Make sure from_native() accepts string values +        """ +        f = serializers.DecimalField() +        result_1 = f.from_native('9000') +        result_2 = f.from_native('1.00000001') + +        self.assertEqual(Decimal('9000'), result_1) +        self.assertEqual(Decimal('1.00000001'), result_2) + +    def test_from_native_invalid_string(self): +        """ +        Make sure from_native() raises ValidationError on passing invalid string +        """ +        f = serializers.DecimalField() + +        try: +            f.from_native('123.45.6') +        except validators.ValidationError as e: +            self.assertEqual(e.messages, ["Enter a number."]) +        else: +            self.fail("ValidationError was not properly raised") + +    def test_from_native_integer(self): +        """ +        Make sure from_native() accepts integer values +        """ +        f = serializers.DecimalField() +        result = f.from_native(9000) + +        self.assertEqual(Decimal('9000'), result) + +    def test_from_native_float(self): +        """ +        Make sure from_native() accepts float values +        """ +        f = serializers.DecimalField() +        result = f.from_native(1.00000001) + +        self.assertEqual(Decimal('1.00000001'), result) + +    def test_from_native_empty(self): +        """ +        Make sure from_native() returns None on empty param. +        """ +        f = serializers.DecimalField() +        result = f.from_native('') + +        self.assertEqual(result, None) + +    def test_from_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DecimalField() +        result = f.from_native(None) + +        self.assertEqual(result, None) + +    def test_to_native(self): +        """ +        Make sure to_native() returns Decimal as string. +        """ +        f = serializers.DecimalField() + +        result_1 = f.to_native(Decimal('9000')) +        result_2 = f.to_native(Decimal('1.00000001')) + +        self.assertEqual('9000', result_1) +        self.assertEqual('1.00000001', result_2) + +    def test_to_native_none(self): +        """ +        Make sure from_native() returns None on None param. +        """ +        f = serializers.DecimalField(required=False) +        self.assertEqual(None, f.to_native(None)) + +    def test_valid_serialization(self): +        """ +        Make sure the serializer works correctly +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_value=9010, +                                                     min_value=9000, +                                                     max_digits=6, +                                                     decimal_places=2) + +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid()) +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid()) +        self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid()) + +        self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid()) +        self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid()) +        self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid()) + +    def test_raise_max_value(self): +        """ +        Make sure max_value violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_value=100) + +        s = DecimalSerializer(data={'decimal_field': '123'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': [u'Ensure this value is less than or equal to 100.']}) + +    def test_raise_min_value(self): +        """ +        Make sure min_value violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(min_value=100) + +        s = DecimalSerializer(data={'decimal_field': '99'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': [u'Ensure this value is greater than or equal to 100.']}) + +    def test_raise_max_digits(self): +        """ +        Make sure max_digits violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_digits=5) + +        s = DecimalSerializer(data={'decimal_field': '123.456'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': [u'Ensure that there are no more than 5 digits in total.']}) + +    def test_raise_max_decimal_places(self): +        """ +        Make sure max_decimal_places violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(decimal_places=3) + +        s = DecimalSerializer(data={'decimal_field': '123.4567'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': [u'Ensure that there are no more than 3 decimal places.']}) + +    def test_raise_max_whole_digits(self): +        """ +        Make sure max_whole_digits violations raises ValidationError +        """ +        class DecimalSerializer(Serializer): +            decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3) + +        s = DecimalSerializer(data={'decimal_field': '12345.6'}) + +        self.assertFalse(s.is_valid()) +        self.assertEqual(s.errors,  {'decimal_field': [u'Ensure that there are no more than 4 digits in total.']})
\ No newline at end of file  | 
