aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2013-04-26 13:59:06 +0100
committerTom Christie2013-04-26 13:59:06 +0100
commitd985aec3c9034ab1aa899e914a5ac0baf314cb3c (patch)
tree4a6f47489a64114c2918947582c5d40b83160f99
parent50c6bc5762460ebd2a79b61edd534d10cb58c7e5 (diff)
parentcac669702596cdf768971267e6355fb9223a69e8 (diff)
downloaddjango-rest-framework-d985aec3c9034ab1aa899e914a5ac0baf314cb3c.tar.bz2
DecimalField
-rw-r--r--docs/api-guide/fields.md6
-rw-r--r--docs/topics/2.3-announcement.md6
-rw-r--r--docs/topics/release-notes.md15
-rw-r--r--rest_framework/fields.py70
-rw-r--r--rest_framework/serializers.py1
-rw-r--r--rest_framework/tests/fields.py165
6 files changed, 263 insertions, 0 deletions
diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md
index 42f89f46..e117c370 100644
--- a/docs/api-guide/fields.md
+++ b/docs/api-guide/fields.md
@@ -248,6 +248,12 @@ A floating point representation.
Corresponds to `django.db.models.fields.FloatField`.
+## DecimalField
+
+A decimal representation.
+
+Corresponds to `django.db.models.fields.DecimalField`.
+
## FileField
A file representation. Performs Django's standard FileField validation.
diff --git a/docs/topics/2.3-announcement.md b/docs/topics/2.3-announcement.md
index 0b80f5e2..554728ae 100644
--- a/docs/topics/2.3-announcement.md
+++ b/docs/topics/2.3-announcement.md
@@ -118,6 +118,12 @@ And would have the following entry in the urlconf:
Usage of the old-style attributes continues to be supported, but will raise a `PendingDeprecationWarning`.
+## DecimalField
+
+2.3 introduces a `DecimalField` serializer field, which returns `Decimal` instances.
+
+For most cases APIs using model fields will behave as previously, however if you are using a custom renderer, not provided by REST framework, then you may now need to add support for rendering `Decimal` instances to your renderer implmentation.
+
---
# Other notes
diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md
index 5e0aa098..8094cc4a 100644
--- a/docs/topics/release-notes.md
+++ b/docs/topics/release-notes.md
@@ -38,6 +38,20 @@ You can determine your currently installed version using `pip freeze`:
---
+## 2.3.x series
+
+### 2.3.0
+
+* ViewSets and Routers.
+* ModelSerializers support reverse relations in 'fields' option.
+* HyperLinkedModelSerializers support 'id' field in 'fields' option.
+* Cleaner generic views.
+* DecimalField support.
+
+**Note**: See the [2.3 announcement][2.3-announcement] for full details.
+
+---
+
## 2.2.x series
### 2.2.7
@@ -458,6 +472,7 @@ This change will not affect user code, so long as it's following the recommended
[django-deprecation-policy]: https://docs.djangoproject.com/en/dev/internals/release-process/#internal-release-deprecation-policy
[defusedxml-announce]: http://blog.python.org/2013/02/announcing-defusedxml-fixes-for-xml.html
[2.2-announcement]: 2.2-announcement.md
+[2.3-announcement]: 2.3-announcement.md
[743]: https://github.com/tomchristie/django-rest-framework/pull/743
[staticfiles14]: https://docs.djangoproject.com/en/1.4/howto/static-files/#with-a-template-tag
[staticfiles13]: https://docs.djangoproject.com/en/1.3/howto/static-files/#with-a-template-tag
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 949f68d6..38fe025d 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -7,6 +7,7 @@ from __future__ import unicode_literals
import copy
import datetime
+from decimal import Decimal, DecimalException
import inspect
import re
import warnings
@@ -726,6 +727,75 @@ class FloatField(WritableField):
raise ValidationError(msg)
+class DecimalField(WritableField):
+ type_name = 'DecimalField'
+ form_field_class = forms.DecimalField
+
+ default_error_messages = {
+ 'invalid': _('Enter a number.'),
+ 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
+ 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
+ 'max_digits': _('Ensure that there are no more than %s digits in total.'),
+ 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
+ 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
+ }
+
+ def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
+ self.max_value, self.min_value = max_value, min_value
+ self.max_digits, self.decimal_places = max_digits, decimal_places
+ super(DecimalField, self).__init__(*args, **kwargs)
+
+ if max_value is not None:
+ self.validators.append(validators.MaxValueValidator(max_value))
+ if min_value is not None:
+ self.validators.append(validators.MinValueValidator(min_value))
+
+ def from_native(self, value):
+ """
+ Validates that the input is a decimal number. Returns a Decimal
+ instance. Returns None for empty values. Ensures that there are no more
+ than max_digits in the number, and no more than decimal_places digits
+ after the decimal point.
+ """
+ if value in validators.EMPTY_VALUES:
+ return None
+ value = smart_text(value).strip()
+ try:
+ value = Decimal(value)
+ except DecimalException:
+ raise ValidationError(self.error_messages['invalid'])
+ return value
+
+ def validate(self, value):
+ super(DecimalField, self).validate(value)
+ if value in validators.EMPTY_VALUES:
+ return
+ # Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
+ # since it is never equal to itself. However, NaN is the only value that
+ # isn't equal to itself, so we can use this to identify NaN
+ if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
+ raise ValidationError(self.error_messages['invalid'])
+ sign, digittuple, exponent = value.as_tuple()
+ decimals = abs(exponent)
+ # digittuple doesn't include any leading zeros.
+ digits = len(digittuple)
+ if decimals > digits:
+ # We have leading zeros up to or past the decimal point. Count
+ # everything past the decimal point as a digit. We do not count
+ # 0 before the decimal point as a digit since that would mean
+ # we would not allow max_digits = decimal_places.
+ digits = decimals
+ whole_digits = digits - decimals
+
+ if self.max_digits is not None and digits > self.max_digits:
+ raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
+ if self.decimal_places is not None and decimals > self.decimal_places:
+ raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
+ if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
+ raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
+ return value
+
+
class FileField(WritableField):
use_files = True
type_name = 'FileField'
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index fb438b12..3d956e4d 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -560,6 +560,7 @@ class ModelSerializer(Serializer):
models.DateTimeField: DateTimeField,
models.DateField: DateField,
models.TimeField: TimeField,
+ models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.CharField: CharField,
models.URLField: URLField,
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index 19c663d8..3cdfa0f6 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -3,12 +3,14 @@ General serializer field tests.
"""
from __future__ import unicode_literals
import datetime
+from decimal import Decimal
from django.db import models
from django.test import TestCase
from django.core import validators
from rest_framework import serializers
+from rest_framework.serializers import Serializer
class TimestampedModel(models.Model):
@@ -481,3 +483,166 @@ class TimeFieldTest(TestCase):
self.assertEqual('04 - 00 [000000]', result_1)
self.assertEqual('04 - 59 [000000]', result_2)
self.assertEqual('04 - 59 [000200]', result_3)
+
+
+class DecimalFieldTest(TestCase):
+ """
+ Tests for the DecimalField from_native() and to_native() behavior
+ """
+
+ def test_from_native_string(self):
+ """
+ Make sure from_native() accepts string values
+ """
+ f = serializers.DecimalField()
+ result_1 = f.from_native('9000')
+ result_2 = f.from_native('1.00000001')
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_from_native_invalid_string(self):
+ """
+ Make sure from_native() raises ValidationError on passing invalid string
+ """
+ f = serializers.DecimalField()
+
+ try:
+ f.from_native('123.45.6')
+ except validators.ValidationError as e:
+ self.assertEqual(e.messages, ["Enter a number."])
+ else:
+ self.fail("ValidationError was not properly raised")
+
+ def test_from_native_integer(self):
+ """
+ Make sure from_native() accepts integer values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(9000)
+
+ self.assertEqual(Decimal('9000'), result)
+
+ def test_from_native_float(self):
+ """
+ Make sure from_native() accepts float values
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(1.00000001)
+
+ self.assertEqual(Decimal('1.00000001'), result)
+
+ def test_from_native_empty(self):
+ """
+ Make sure from_native() returns None on empty param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native('')
+
+ self.assertEqual(result, None)
+
+ def test_from_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField()
+ result = f.from_native(None)
+
+ self.assertEqual(result, None)
+
+ def test_to_native(self):
+ """
+ Make sure to_native() returns Decimal as string.
+ """
+ f = serializers.DecimalField()
+
+ result_1 = f.to_native(Decimal('9000'))
+ result_2 = f.to_native(Decimal('1.00000001'))
+
+ self.assertEqual(Decimal('9000'), result_1)
+ self.assertEqual(Decimal('1.00000001'), result_2)
+
+ def test_to_native_none(self):
+ """
+ Make sure from_native() returns None on None param.
+ """
+ f = serializers.DecimalField(required=False)
+ self.assertEqual(None, f.to_native(None))
+
+ def test_valid_serialization(self):
+ """
+ Make sure the serializer works correctly
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=9010,
+ min_value=9000,
+ max_digits=6,
+ decimal_places=2)
+
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.2'}).is_valid())
+ self.assertTrue(DecimalSerializer(data={'decimal_field': '9001.23'}).is_valid())
+
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '8000'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9900'}).is_valid())
+ self.assertFalse(DecimalSerializer(data={'decimal_field': '9001.234'}).is_valid())
+
+ def test_raise_max_value(self):
+ """
+ Make sure max_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '123'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is less than or equal to 100.']})
+
+ def test_raise_min_value(self):
+ """
+ Make sure min_value violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(min_value=100)
+
+ s = DecimalSerializer(data={'decimal_field': '99'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure this value is greater than or equal to 100.']})
+
+ def test_raise_max_digits(self):
+ """
+ Make sure max_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=5)
+
+ s = DecimalSerializer(data={'decimal_field': '123.456'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 5 digits in total.']})
+
+ def test_raise_max_decimal_places(self):
+ """
+ Make sure max_decimal_places violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '123.4567'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 3 decimal places.']})
+
+ def test_raise_max_whole_digits(self):
+ """
+ Make sure max_whole_digits violations raises ValidationError
+ """
+ class DecimalSerializer(Serializer):
+ decimal_field = serializers.DecimalField(max_digits=4, decimal_places=3)
+
+ s = DecimalSerializer(data={'decimal_field': '12345.6'})
+
+ self.assertFalse(s.is_valid())
+ self.assertEqual(s.errors, {'decimal_field': ['Ensure that there are no more than 4 digits in total.']}) \ No newline at end of file