aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-03-22 21:57:37 +0000
committerTom Christie2013-03-22 21:57:37 +0000
commit9bf7c9b714713f7b2fe84074cfd05a8bc3ef4022 (patch)
tree932d97342b9adc7a25b6620fde35d07184ed3c58 /rest_framework
parentdeb5e653e441bf31f3b183b575f72e6b4cf537ea (diff)
parent870d5c7d7810ecd7f187e13b5fe3a3bcba6b18c3 (diff)
downloaddjango-rest-framework-9bf7c9b714713f7b2fe84074cfd05a8bc3ef4022.tar.bz2
Merge master
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/authentication.py3
-rw-r--r--rest_framework/fields.py23
-rw-r--r--rest_framework/serializers.py73
-rw-r--r--rest_framework/tests/fields.py43
-rw-r--r--rest_framework/tests/filterset.py9
-rw-r--r--rest_framework/tests/pagination.py2
-rw-r--r--rest_framework/tests/serializer.py30
-rw-r--r--rest_framework/tests/serializer_bulk_update.py181
8 files changed, 312 insertions, 52 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index b4b73699..8f4ec536 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -204,6 +204,9 @@ class OAuthAuthentication(BaseAuthentication):
except oauth.Error as err:
raise exceptions.AuthenticationFailed(err.message)
+ if not oauth_request:
+ return None
+
oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES
found = any(param for param in oauth_params if param in oauth_request)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 4b6931ad..f3496b53 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -494,7 +494,7 @@ class DateField(WritableField):
}
empty = None
input_formats = api_settings.DATE_INPUT_FORMATS
- format = api_settings.DATE_FORMAT
+ format = None
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -536,8 +536,8 @@ class DateField(WritableField):
raise ValidationError(msg)
def to_native(self, value):
- if value is None:
- return None
+ if value is None or self.format is None:
+ return value
if isinstance(value, datetime.datetime):
value = value.date()
@@ -557,7 +557,7 @@ class DateTimeField(WritableField):
}
empty = None
input_formats = api_settings.DATETIME_INPUT_FORMATS
- format = api_settings.DATETIME_FORMAT
+ format = None
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -605,11 +605,14 @@ class DateTimeField(WritableField):
raise ValidationError(msg)
def to_native(self, value):
- if value is None:
- return None
+ if value is None or self.format is None:
+ return value
if self.format.lower() == ISO_8601:
- return value.isoformat()
+ ret = value.isoformat()
+ if ret.endswith('+00:00'):
+ ret = ret[:-6] + 'Z'
+ return ret
return value.strftime(self.format)
@@ -623,7 +626,7 @@ class TimeField(WritableField):
}
empty = None
input_formats = api_settings.TIME_INPUT_FORMATS
- format = api_settings.TIME_FORMAT
+ format = None
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -658,8 +661,8 @@ class TimeField(WritableField):
raise ValidationError(msg)
def to_native(self, value):
- if value is None:
- return None
+ if value is None or self.format is None:
+ return value
if isinstance(value, datetime.datetime):
value = value.time()
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index a81cbc29..26c34044 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -129,13 +129,15 @@ class BaseSerializer(WritableField):
_dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=None, **kwargs):
+ context=None, partial=False, many=None,
+ allow_delete=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
self.many = many
+ self.allow_delete = allow_delete
self.context = context or {}
@@ -147,6 +149,13 @@ class BaseSerializer(WritableField):
self._data = None
self._files = None
self._errors = None
+ self._deleted = None
+
+ if many and instance is not None and not hasattr(instance, '__iter__'):
+ raise ValueError('instance should be a queryset or other iterable with many=True')
+
+ if allow_delete and not many:
+ raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True')
#####
# Methods to determine which fields to use when (de)serializing objects.
@@ -387,6 +396,20 @@ class BaseSerializer(WritableField):
# Propagate errors up to our parent
raise NestedValidationError(serializer.errors)
+ def get_identity(self, data):
+ """
+ This hook is required for bulk update.
+ It is used to determine the canonical identity of a given object.
+
+ Note that the data has not been validated at this point, so we need
+ to make sure that we catch any cases of incorrect datatypes being
+ passed to this method.
+ """
+ try:
+ return data.get('id', None)
+ except AttributeError:
+ return None
+
@property
def errors(self):
"""
@@ -408,10 +431,33 @@ class BaseSerializer(WritableField):
if many:
ret = []
errors = []
- for item in data:
- ret.append(self.from_native(item, None))
- errors.append(self._errors)
- self._errors = any(errors) and errors or []
+ update = self.object is not None
+
+ if update:
+ # If this is a bulk update we need to map all the objects
+ # to a canonical identity so we can determine which
+ # individual object is being updated for each item in the
+ # incoming data
+ objects = self.object
+ identities = [self.get_identity(self.to_native(obj)) for obj in objects]
+ identity_to_objects = dict(zip(identities, objects))
+
+ if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
+ for item in data:
+ if update:
+ # Determine which object we're updating
+ identity = self.get_identity(item)
+ self.object = identity_to_objects.pop(identity, None)
+
+ ret.append(self.from_native(item, None))
+ errors.append(self._errors)
+
+ if update:
+ self._deleted = identity_to_objects.values()
+
+ self._errors = any(errors) and errors or []
+ else:
+ self._errors = {'non_field_errors': ['Expected a list of items']}
else:
ret = self.from_native(data, files)
@@ -450,6 +496,9 @@ class BaseSerializer(WritableField):
def save_object(self, obj, **kwargs):
obj.save(**kwargs)
+ def delete_object(self, obj):
+ obj.delete()
+
def save(self, **kwargs):
"""
Save the deserialized object and return it.
@@ -458,6 +507,10 @@ class BaseSerializer(WritableField):
[self.save_object(item, **kwargs) for item in self.object]
else:
self.save_object(self.object, **kwargs)
+
+ if self.allow_delete and self._deleted:
+ [self.delete_object(item) for item in self._deleted]
+
return self.object
@@ -765,3 +818,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
'many': to_many
}
return HyperlinkedRelatedField(**kwargs)
+
+ def get_identity(self, data):
+ """
+ This hook is required for bulk update.
+ We need to override the default, to use the url as the identity.
+ """
+ try:
+ return data.get('url', None)
+ except AttributeError:
+ return None
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index fd6de779..19c663d8 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -153,12 +153,22 @@ class DateFieldTest(TestCase):
def test_to_native(self):
"""
- Make sure to_native() returns isoformat as default.
+ Make sure to_native() returns datetime as default.
"""
f = serializers.DateField()
result_1 = f.to_native(datetime.date(1984, 7, 31))
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with 'iso-8601' returns iso formated date.
+ """
+ f = serializers.DateField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
self.assertEqual('1984-07-31', result_1)
def test_to_native_custom_format(self):
@@ -289,6 +299,22 @@ class DateTimeFieldTest(TestCase):
result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+ self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format=iso-8601 returns iso formatted datetime.
+ """
+ f = serializers.DateTimeField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
self.assertEqual('1984-07-31T00:00:00', result_1)
self.assertEqual('1984-07-31T04:31:00', result_2)
self.assertEqual('1984-07-31T04:31:59', result_3)
@@ -419,13 +445,26 @@ class TimeFieldTest(TestCase):
def test_to_native(self):
"""
- Make sure to_native() returns isoformat as default.
+ Make sure to_native() returns time object as default.
"""
f = serializers.TimeField()
result_1 = f.to_native(datetime.time(4, 31))
result_2 = f.to_native(datetime.time(4, 31, 59))
result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format='iso-8601' returns iso formatted time.
+ """
+ f = serializers.TimeField(format='iso-8601')
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
self.assertEqual('04:31:00', result_1)
self.assertEqual('04:31:59', result_2)
self.assertEqual('04:31:59.000200', result_3)
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index fe92e0bc..238da56e 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -65,7 +65,7 @@ class IntegrationTestFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
for obj in self.objects.all()
]
@@ -95,7 +95,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date]
+ expected_data = [f for f in self.data if f['date'] == search_date]
self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
@@ -125,7 +125,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date]
+ expected_data = [f for f in self.data if f['date'] > search_date]
self.assertEqual(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works.
@@ -142,8 +142,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if
- datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and
+ expected_data = [f for f in self.data if f['date'] > search_date and
f['decimal'] < search_decimal]
self.assertEqual(response.data, expected_data)
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 1a2d68a6..d2c9b051 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -102,7 +102,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
for obj in self.objects.all()
]
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index beb372c2..05217f35 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -112,7 +112,7 @@ class BasicTests(TestCase):
self.expected = {
'email': 'tom@example.com',
'content': 'Happy new year!',
- 'created': '2012-01-01T00:00:00',
+ 'created': datetime.datetime(2012, 1, 1),
'sub_comment': 'And Merry Christmas!'
}
self.person_data = {'name': 'dwight', 'age': 35}
@@ -261,34 +261,6 @@ class ValidationTests(TestCase):
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.errors, {})
- def test_bad_type_data_is_false(self):
- """
- Data of the wrong type is not valid.
- """
- data = ['i am', 'a', 'list']
- serializer = CommentSerializer(self.comment, data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertTrue(isinstance(serializer.errors, list))
-
- self.assertEqual(
- serializer.errors,
- [
- {'non_field_errors': ['Invalid data']},
- {'non_field_errors': ['Invalid data']},
- {'non_field_errors': ['Invalid data']}
- ]
- )
-
- data = 'and i am a string'
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
- data = 42
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
def test_cross_field_validation(self):
class CommentSerializerWithCrossFieldValidator(CommentSerializer):
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py
index 3ecb23ed..afc1a1a9 100644
--- a/rest_framework/tests/serializer_bulk_update.py
+++ b/rest_framework/tests/serializer_bulk_update.py
@@ -7,6 +7,9 @@ from rest_framework import serializers
class BulkCreateSerializerTests(TestCase):
+ """
+ Creating multiple instances using serializers.
+ """
def setUp(self):
class BookSerializer(serializers.Serializer):
@@ -71,3 +74,181 @@ class BulkCreateSerializerTests(TestCase):
self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors)
+ def test_invalid_list_datatype(self):
+ """
+ Data containing list of incorrect data type should return errors.
+ """
+ data = ['foo', 'bar', 'baz']
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = [
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']}
+ ]
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_datatype(self):
+ """
+ Data containing a single incorrect data type should return errors.
+ """
+ data = 123
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_object(self):
+ """
+ Data containing only a single object, instead of a list of objects
+ should return errors.
+ """
+ data = {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class BulkUpdateSerializerTests(TestCase):
+ """
+ Updating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class Book(object):
+ """
+ A data type that can be persisted to a mock storage backend
+ with `.save()` and `.delete()`.
+ """
+ object_map = {}
+
+ def __init__(self, id, title, author):
+ self.id = id
+ self.title = title
+ self.author = author
+
+ def save(self):
+ Book.object_map[self.id] = self
+
+ def delete(self):
+ del Book.object_map[self.id]
+
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.id = attrs['id']
+ instance.title = attrs['title']
+ instance.author = attrs['author']
+ return instance
+ return Book(**attrs)
+
+ self.Book = Book
+ self.BookSerializer = BookSerializer
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ for item in data:
+ book = Book(item['id'], item['title'], item['author'])
+ book.save()
+
+ def books(self):
+ """
+ Return all the objects in the mock storage backend.
+ """
+ return self.Book.object_map.values()
+
+ def test_bulk_update_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 2,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_and_create(self):
+ """
+ Bulk update serialization may also include created items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_error(self):
+ """
+ Incorrect bulk update serialization should return error data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 'foo',
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)