diff options
| author | Tom Christie | 2013-03-22 21:57:37 +0000 | 
|---|---|---|
| committer | Tom Christie | 2013-03-22 21:57:37 +0000 | 
| commit | 9bf7c9b714713f7b2fe84074cfd05a8bc3ef4022 (patch) | |
| tree | 932d97342b9adc7a25b6620fde35d07184ed3c58 /rest_framework | |
| parent | deb5e653e441bf31f3b183b575f72e6b4cf537ea (diff) | |
| parent | 870d5c7d7810ecd7f187e13b5fe3a3bcba6b18c3 (diff) | |
| download | django-rest-framework-9bf7c9b714713f7b2fe84074cfd05a8bc3ef4022.tar.bz2 | |
Merge master
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/authentication.py | 3 | ||||
| -rw-r--r-- | rest_framework/fields.py | 23 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 73 | ||||
| -rw-r--r-- | rest_framework/tests/fields.py | 43 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 2 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 30 | ||||
| -rw-r--r-- | rest_framework/tests/serializer_bulk_update.py | 181 | 
8 files changed, 312 insertions, 52 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index b4b73699..8f4ec536 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -204,6 +204,9 @@ class OAuthAuthentication(BaseAuthentication):          except oauth.Error as err:              raise exceptions.AuthenticationFailed(err.message) +        if not oauth_request: +            return None +          oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES          found = any(param for param in oauth_params if param in oauth_request) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4b6931ad..f3496b53 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -494,7 +494,7 @@ class DateField(WritableField):      }      empty = None      input_formats = api_settings.DATE_INPUT_FORMATS -    format = api_settings.DATE_FORMAT +    format = None      def __init__(self, input_formats=None, format=None, *args, **kwargs):          self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -536,8 +536,8 @@ class DateField(WritableField):          raise ValidationError(msg)      def to_native(self, value): -        if value is None: -            return None +        if value is None or self.format is None: +            return value          if isinstance(value, datetime.datetime):              value = value.date() @@ -557,7 +557,7 @@ class DateTimeField(WritableField):      }      empty = None      input_formats = api_settings.DATETIME_INPUT_FORMATS -    format = api_settings.DATETIME_FORMAT +    format = None      def __init__(self, input_formats=None, format=None, *args, **kwargs):          self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -605,11 +605,14 @@ class DateTimeField(WritableField):          raise ValidationError(msg)      def to_native(self, value): -        if value is None: -            return None +        if value is None or self.format is None: +            return value          if self.format.lower() == ISO_8601: -            return value.isoformat() +            ret = value.isoformat() +            if ret.endswith('+00:00'): +                ret = ret[:-6] + 'Z' +            return ret          return value.strftime(self.format) @@ -623,7 +626,7 @@ class TimeField(WritableField):      }      empty = None      input_formats = api_settings.TIME_INPUT_FORMATS -    format = api_settings.TIME_FORMAT +    format = None      def __init__(self, input_formats=None, format=None, *args, **kwargs):          self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -658,8 +661,8 @@ class TimeField(WritableField):          raise ValidationError(msg)      def to_native(self, value): -        if value is None: -            return None +        if value is None or self.format is None: +            return value          if isinstance(value, datetime.datetime):              value = value.time() diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index a81cbc29..26c34044 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -129,13 +129,15 @@ class BaseSerializer(WritableField):      _dict_class = SortedDictWithMetadata      def __init__(self, instance=None, data=None, files=None, -                 context=None, partial=False, many=None, **kwargs): +                 context=None, partial=False, many=None, +                 allow_delete=False, **kwargs):          super(BaseSerializer, self).__init__(**kwargs)          self.opts = self._options_class(self.Meta)          self.parent = None          self.root = None          self.partial = partial          self.many = many +        self.allow_delete = allow_delete          self.context = context or {} @@ -147,6 +149,13 @@ class BaseSerializer(WritableField):          self._data = None          self._files = None          self._errors = None +        self._deleted = None + +        if many and instance is not None and not hasattr(instance, '__iter__'): +            raise ValueError('instance should be a queryset or other iterable with many=True') + +        if allow_delete and not many: +            raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True')      #####      # Methods to determine which fields to use when (de)serializing objects. @@ -387,6 +396,20 @@ class BaseSerializer(WritableField):                  # Propagate errors up to our parent                  raise NestedValidationError(serializer.errors) +    def get_identity(self, data): +        """ +        This hook is required for bulk update. +        It is used to determine the canonical identity of a given object. + +        Note that the data has not been validated at this point, so we need +        to make sure that we catch any cases of incorrect datatypes being +        passed to this method. +        """ +        try: +            return data.get('id', None) +        except AttributeError: +            return None +      @property      def errors(self):          """ @@ -408,10 +431,33 @@ class BaseSerializer(WritableField):              if many:                  ret = []                  errors = [] -                for item in data: -                    ret.append(self.from_native(item, None)) -                    errors.append(self._errors) -                self._errors = any(errors) and errors or [] +                update = self.object is not None + +                if update: +                    # If this is a bulk update we need to map all the objects +                    # to a canonical identity so we can determine which +                    # individual object is being updated for each item in the +                    # incoming data +                    objects = self.object +                    identities = [self.get_identity(self.to_native(obj)) for obj in objects] +                    identity_to_objects = dict(zip(identities, objects)) + +                if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): +                    for item in data: +                        if update: +                            # Determine which object we're updating +                            identity = self.get_identity(item) +                            self.object = identity_to_objects.pop(identity, None) + +                        ret.append(self.from_native(item, None)) +                        errors.append(self._errors) + +                    if update: +                        self._deleted = identity_to_objects.values() + +                    self._errors = any(errors) and errors or [] +                else: +                    self._errors = {'non_field_errors': ['Expected a list of items']}              else:                  ret = self.from_native(data, files) @@ -450,6 +496,9 @@ class BaseSerializer(WritableField):      def save_object(self, obj, **kwargs):          obj.save(**kwargs) +    def delete_object(self, obj): +        obj.delete() +      def save(self, **kwargs):          """          Save the deserialized object and return it. @@ -458,6 +507,10 @@ class BaseSerializer(WritableField):              [self.save_object(item, **kwargs) for item in self.object]          else:              self.save_object(self.object, **kwargs) + +        if self.allow_delete and self._deleted: +            [self.delete_object(item) for item in self._deleted] +          return self.object @@ -765,3 +818,13 @@ class HyperlinkedModelSerializer(ModelSerializer):              'many': to_many          }          return HyperlinkedRelatedField(**kwargs) + +    def get_identity(self, data): +        """ +        This hook is required for bulk update. +        We need to override the default, to use the url as the identity. +        """ +        try: +            return data.get('url', None) +        except AttributeError: +            return None diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index fd6de779..19c663d8 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -153,12 +153,22 @@ class DateFieldTest(TestCase):      def test_to_native(self):          """ -        Make sure to_native() returns isoformat as default. +        Make sure to_native() returns datetime as default.          """          f = serializers.DateField()          result_1 = f.to_native(datetime.date(1984, 7, 31)) +        self.assertEqual(datetime.date(1984, 7, 31), result_1) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with 'iso-8601' returns iso formated date. +        """ +        f = serializers.DateField(format='iso-8601') + +        result_1 = f.to_native(datetime.date(1984, 7, 31)) +          self.assertEqual('1984-07-31', result_1)      def test_to_native_custom_format(self): @@ -289,6 +299,22 @@ class DateTimeFieldTest(TestCase):          result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))          result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) +        self.assertEqual(datetime.datetime(1984, 7, 31), result_1) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with format=iso-8601 returns iso formatted datetime. +        """ +        f = serializers.DateTimeField(format='iso-8601') + +        result_1 = f.to_native(datetime.datetime(1984, 7, 31)) +        result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) +        result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) +        result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) +          self.assertEqual('1984-07-31T00:00:00', result_1)          self.assertEqual('1984-07-31T04:31:00', result_2)          self.assertEqual('1984-07-31T04:31:59', result_3) @@ -419,13 +445,26 @@ class TimeFieldTest(TestCase):      def test_to_native(self):          """ -        Make sure to_native() returns isoformat as default. +        Make sure to_native() returns time object as default.          """          f = serializers.TimeField()          result_1 = f.to_native(datetime.time(4, 31))          result_2 = f.to_native(datetime.time(4, 31, 59))          result_3 = f.to_native(datetime.time(4, 31, 59, 200)) +        self.assertEqual(datetime.time(4, 31), result_1) +        self.assertEqual(datetime.time(4, 31, 59), result_2) +        self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with format='iso-8601' returns iso formatted time. +        """ +        f = serializers.TimeField(format='iso-8601') +        result_1 = f.to_native(datetime.time(4, 31)) +        result_2 = f.to_native(datetime.time(4, 31, 59)) +        result_3 = f.to_native(datetime.time(4, 31, 59, 200)) +          self.assertEqual('04:31:00', result_1)          self.assertEqual('04:31:59', result_2)          self.assertEqual('04:31:59.000200', result_3) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index fe92e0bc..238da56e 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -65,7 +65,7 @@ class IntegrationTestFiltering(TestCase):          self.objects = FilterableItem.objects          self.data = [ -            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} +            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}              for obj in self.objects.all()          ] @@ -95,7 +95,7 @@ class IntegrationTestFiltering(TestCase):          request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-09-22'          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date] +        expected_data = [f for f in self.data if f['date'] == search_date]          self.assertEqual(response.data, expected_data)      @unittest.skipUnless(django_filters, 'django-filters not installed') @@ -125,7 +125,7 @@ class IntegrationTestFiltering(TestCase):          request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-10-02'          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date] +        expected_data = [f for f in self.data if f['date'] > search_date]          self.assertEqual(response.data, expected_data)          # Tests that the text filter set with 'icontains' in the filter class works. @@ -142,8 +142,7 @@ class IntegrationTestFiltering(TestCase):          request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if -                         datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and +        expected_data = [f for f in self.data if f['date'] > search_date and                           f['decimal'] < search_decimal]          self.assertEqual(response.data, expected_data) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 1a2d68a6..d2c9b051 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -102,7 +102,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.objects = FilterableItem.objects          self.data = [ -            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} +            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}              for obj in self.objects.all()          ] diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index beb372c2..05217f35 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -112,7 +112,7 @@ class BasicTests(TestCase):          self.expected = {              'email': 'tom@example.com',              'content': 'Happy new year!', -            'created': '2012-01-01T00:00:00', +            'created': datetime.datetime(2012, 1, 1),              'sub_comment': 'And Merry Christmas!'          }          self.person_data = {'name': 'dwight', 'age': 35} @@ -261,34 +261,6 @@ class ValidationTests(TestCase):          self.assertEqual(serializer.is_valid(), True)          self.assertEqual(serializer.errors, {}) -    def test_bad_type_data_is_false(self): -        """ -        Data of the wrong type is not valid. -        """ -        data = ['i am', 'a', 'list'] -        serializer = CommentSerializer(self.comment, data=data, many=True) -        self.assertEqual(serializer.is_valid(), False) -        self.assertTrue(isinstance(serializer.errors, list)) - -        self.assertEqual( -            serializer.errors, -            [ -                {'non_field_errors': ['Invalid data']}, -                {'non_field_errors': ['Invalid data']}, -                {'non_field_errors': ['Invalid data']} -            ] -        ) - -        data = 'and i am a string' -        serializer = CommentSerializer(self.comment, data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) - -        data = 42 -        serializer = CommentSerializer(self.comment, data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) -      def test_cross_field_validation(self):          class CommentSerializerWithCrossFieldValidator(CommentSerializer): diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py index 3ecb23ed..afc1a1a9 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/serializer_bulk_update.py @@ -7,6 +7,9 @@ from rest_framework import serializers  class BulkCreateSerializerTests(TestCase): +    """ +    Creating multiple instances using serializers. +    """      def setUp(self):          class BookSerializer(serializers.Serializer): @@ -71,3 +74,181 @@ class BulkCreateSerializerTests(TestCase):          self.assertEqual(serializer.is_valid(), False)          self.assertEqual(serializer.errors, expected_errors) +    def test_invalid_list_datatype(self): +        """ +        Data containing list of incorrect data type should return errors. +        """ +        data = ['foo', 'bar', 'baz'] +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = [ +                {'non_field_errors': ['Invalid data']}, +                {'non_field_errors': ['Invalid data']}, +                {'non_field_errors': ['Invalid data']} +        ] + +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_single_datatype(self): +        """ +        Data containing a single incorrect data type should return errors. +        """ +        data = 123 +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = {'non_field_errors': ['Expected a list of items']} + +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_single_object(self): +        """ +        Data containing only a single object, instead of a list of objects +        should return errors. +        """ +        data = { +            'id': 0, +            'title': 'The electric kool-aid acid test', +            'author': 'Tom Wolfe' +        } +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = {'non_field_errors': ['Expected a list of items']} + +        self.assertEqual(serializer.errors, expected_errors) + + +class BulkUpdateSerializerTests(TestCase): +    """ +    Updating multiple instances using serializers. +    """ + +    def setUp(self): +        class Book(object): +            """ +            A data type that can be persisted to a mock storage backend +            with `.save()` and `.delete()`. +            """ +            object_map = {} + +            def __init__(self, id, title, author): +                self.id = id +                self.title = title +                self.author = author + +            def save(self): +                Book.object_map[self.id] = self + +            def delete(self): +                del Book.object_map[self.id] + +        class BookSerializer(serializers.Serializer): +            id = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            author = serializers.CharField(max_length=100) + +            def restore_object(self, attrs, instance=None): +                if instance: +                    instance.id = attrs['id'] +                    instance.title = attrs['title'] +                    instance.author = attrs['author'] +                    return instance +                return Book(**attrs) + +        self.Book = Book +        self.BookSerializer = BookSerializer + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 2, +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] + +        for item in data: +            book = Book(item['id'], item['title'], item['author']) +            book.save() + +    def books(self): +        """ +        Return all the objects in the mock storage backend. +        """ +        return self.Book.object_map.values() + +    def test_bulk_update_success(self): +        """ +        Correct bulk update serialization should return the input data. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 2, +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) +        serializer.save() +        new_data = self.BookSerializer(self.books(), many=True).data +        self.assertEqual(data, new_data) + +    def test_bulk_update_and_create(self): +        """ +        Bulk update serialization may also include created items. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 3, +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) +        serializer.save() +        new_data = self.BookSerializer(self.books(), many=True).data +        self.assertEqual(data, new_data) + +    def test_bulk_update_error(self): +        """ +        Incorrect bulk update serialization should return error data. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 'foo', +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        expected_errors = [ +            {}, +            {'id': ['Enter a whole number.']} +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors)  | 
