diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/authentication.py | 29 | ||||
| -rw-r--r-- | rest_framework/compat.py | 33 | ||||
| -rw-r--r-- | rest_framework/fields.py | 23 | ||||
| -rw-r--r-- | rest_framework/filters.py | 2 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 217 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 84 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 44 | ||||
| -rw-r--r-- | rest_framework/tests/fields.py | 43 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 82 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 12 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 30 | ||||
| -rw-r--r-- | rest_framework/tests/serializer_bulk_update.py | 278 | ||||
| -rw-r--r-- | rest_framework/tests/serializer_nested.py | 246 | ||||
| -rw-r--r-- | rest_framework/tests/status.py | 13 |
15 files changed, 931 insertions, 207 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index cf005636..c86403d8 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.4' +__version__ = '2.2.5' VERSION = __version__ # synonym diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index b4b73699..145d4295 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -2,14 +2,16 @@ Provides a set of pluggable authentication policies. """ from __future__ import unicode_literals +import base64 +from datetime import datetime + from django.contrib.auth import authenticate from django.core.exceptions import ImproperlyConfigured from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends +from rest_framework.compat import oauth2_provider, oauth2_provider_forms from rest_framework.authtoken.models import Token -import base64 def get_authorization_header(request): @@ -204,6 +206,9 @@ class OAuthAuthentication(BaseAuthentication): except oauth.Error as err: raise exceptions.AuthenticationFailed(err.message) + if not oauth_request: + return None + oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES found = any(param for param in oauth_params if param in oauth_request) @@ -312,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication): Authenticate the request, given the access token. """ - # Authenticate the client - oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST) - if not oauth2_client_form.is_valid(): - raise exceptions.AuthenticationFailed('Client could not be validated') - client = oauth2_client_form.cleaned_data.get('client') - - # Retrieve the `OAuth2AccessToken` instance from the access_token - auth_backend = oauth2_provider_backends.AccessTokenBackend() - token = auth_backend.authenticate(access_token, client) - if token is None: + try: + token = oauth2_provider.models.AccessToken.objects.select_related('user') + # TODO: Change to timezone aware datetime when oauth2_provider add + # support to it. + token = token.get(token=access_token, expires__gt=datetime.now()) + except oauth2_provider.models.AccessToken.DoesNotExist: raise exceptions.AuthenticationFailed('Invalid token') - user = token.user - - if not user.is_active: + if not token.user.is_active: msg = 'User inactive or deleted: %s' % user.username raise exceptions.AuthenticationFailed(msg) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7b2ef738..6551723a 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -395,6 +395,37 @@ except ImportError: kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None) return datetime.datetime(**kw) + +# smart_urlquote is new on Django 1.4 +try: + from django.utils.html import smart_urlquote +except ImportError: + try: + from urllib.parse import quote, urlsplit, urlunsplit + except ImportError: # Python 2 + from urllib import quote + from urlparse import urlsplit, urlunsplit + + def smart_urlquote(url): + "Quotes a URL if it isn't already quoted." + # Handle IDN before quoting. + scheme, netloc, path, query, fragment = urlsplit(url) + try: + netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE + except UnicodeError: # invalid domain part + pass + else: + url = urlunsplit((scheme, netloc, path, query, fragment)) + + # An URL is considered unquoted if it contains no % characters or + # contains a % not followed by two hexadecimal digits. See #9655. + if '%' not in url or unquoted_percents_re.search(url): + # See http://bugs.python.org/issue2637 + url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~') + + return force_text(url) + + # Markdown is optional try: import markdown @@ -445,14 +476,12 @@ except ImportError: # OAuth 2 support is optional try: import provider.oauth2 as oauth2_provider - from provider.oauth2 import backends as oauth2_provider_backends from provider.oauth2 import models as oauth2_provider_models from provider.oauth2 import forms as oauth2_provider_forms from provider import scope as oauth2_provider_scope from provider import constants as oauth2_constants except ImportError: oauth2_provider = None - oauth2_provider_backends = None oauth2_provider_models = None oauth2_provider_forms = None oauth2_provider_scope = None diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4b6931ad..f3496b53 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -494,7 +494,7 @@ class DateField(WritableField): } empty = None input_formats = api_settings.DATE_INPUT_FORMATS - format = api_settings.DATE_FORMAT + format = None def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -536,8 +536,8 @@ class DateField(WritableField): raise ValidationError(msg) def to_native(self, value): - if value is None: - return None + if value is None or self.format is None: + return value if isinstance(value, datetime.datetime): value = value.date() @@ -557,7 +557,7 @@ class DateTimeField(WritableField): } empty = None input_formats = api_settings.DATETIME_INPUT_FORMATS - format = api_settings.DATETIME_FORMAT + format = None def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -605,11 +605,14 @@ class DateTimeField(WritableField): raise ValidationError(msg) def to_native(self, value): - if value is None: - return None + if value is None or self.format is None: + return value if self.format.lower() == ISO_8601: - return value.isoformat() + ret = value.isoformat() + if ret.endswith('+00:00'): + ret = ret[:-6] + 'Z' + return ret return value.strftime(self.format) @@ -623,7 +626,7 @@ class TimeField(WritableField): } empty = None input_formats = api_settings.TIME_INPUT_FORMATS - format = api_settings.TIME_FORMAT + format = None def __init__(self, input_formats=None, format=None, *args, **kwargs): self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -658,8 +661,8 @@ class TimeField(WritableField): raise ValidationError(msg) def to_native(self, value): - if value is None: - return None + if value is None or self.format is None: + return value if isinstance(value, datetime.datetime): value = value.time() diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 6fea46fa..413fa0d2 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend): filter_class = self.get_filter_class(view) if filter_class: - return filter_class(request.QUERY_PARAMS, queryset=queryset) + return filter_class(request.QUERY_PARAMS, queryset=queryset).qs return queryset diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4fe857a6..1b2b0821 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,25 @@ from rest_framework.relations import * from rest_framework.fields import * +class NestedValidationError(ValidationError): + """ + The default ValidationError behavior is to stringify each item in the list + if the messages are a list of error messages. + + In the case of nested serializers, where the parent has many children, + then the child's `serializer.errors` will be a list of dicts. In the case + of a single child, the `serializer.errors` will be a dict. + + We need to override the default behavior to get properly nested error dicts. + """ + + def __init__(self, message): + if isinstance(message, dict): + self.messages = [message] + else: + self.messages = message + + class DictWithMetadata(dict): """ A dict-like object, that can have additional properties attached. @@ -98,7 +117,7 @@ class SerializerOptions(object): self.exclude = getattr(meta, 'exclude', ()) -class BaseSerializer(Field): +class BaseSerializer(WritableField): """ This is the Serializer implementation. We need to implement it as `BaseSerializer` due to metaclass magicks. @@ -110,13 +129,15 @@ class BaseSerializer(Field): _dict_class = SortedDictWithMetadata def __init__(self, instance=None, data=None, files=None, - context=None, partial=False, many=None, source=None): - super(BaseSerializer, self).__init__(source=source) + context=None, partial=False, many=None, + allow_add_remove=False, **kwargs): + super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) self.parent = None self.root = None self.partial = partial self.many = many + self.allow_add_remove = allow_add_remove self.context = context or {} @@ -128,6 +149,13 @@ class BaseSerializer(Field): self._data = None self._files = None self._errors = None + self._deleted = None + + if many and instance is not None and not hasattr(instance, '__iter__'): + raise ValueError('instance should be a queryset or other iterable with many=True') + + if allow_add_remove and not many: + raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') ##### # Methods to determine which fields to use when (de)serializing objects. @@ -296,40 +324,91 @@ class BaseSerializer(Field): def field_to_native(self, obj, field_name): """ - Override default so that we can apply ModelSerializer as a nested - field to relationships. + Override default so that the serializer can be used as a nested field + across relationships. """ if self.source == '*': return self.to_native(obj) try: - if self.source: - for component in self.source.split('.'): - obj = getattr(obj, component) - if is_simple_callable(obj): - obj = obj() - else: - obj = getattr(obj, field_name) - if is_simple_callable(obj): - obj = obj() + source = self.source or field_name + value = obj + + for component in source.split('.'): + value = get_component(value, component) + if value is None: + break except ObjectDoesNotExist: return None - # If the object has an "all" method, assume it's a relationship - if is_simple_callable(getattr(obj, 'all', None)): - return [self.to_native(item) for item in obj.all()] + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item) for item in value.all()] - if obj is None: + if value is None: return None if self.many is not None: many = self.many else: - many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict, six.text_type)) + many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type)) if many: - return [self.to_native(item) for item in obj] - return self.to_native(obj) + return [self.to_native(item) for item in value] + return self.to_native(value) + + def field_from_native(self, data, files, field_name, into): + """ + Override default so that the serializer can be used as a writable + nested field across relationships. + """ + if self.read_only: + return + + try: + value = data[field_name] + except KeyError: + if self.default is not None and not self.partial: + # Note: partial updates shouldn't set defaults + value = copy.deepcopy(self.default) + else: + if self.required: + raise ValidationError(self.error_messages['required']) + return + + # Set the serializer object if it exists + obj = getattr(self.parent.object, field_name) if self.parent.object else None + + if value in (None, ''): + into[(self.source or field_name)] = None + else: + kwargs = { + 'instance': obj, + 'data': value, + 'context': self.context, + 'partial': self.partial, + 'many': self.many + } + serializer = self.__class__(**kwargs) + + if serializer.is_valid(): + into[self.source or field_name] = serializer.object + else: + # Propagate errors up to our parent + raise NestedValidationError(serializer.errors) + + def get_identity(self, data): + """ + This hook is required for bulk update. + It is used to determine the canonical identity of a given object. + + Note that the data has not been validated at this point, so we need + to make sure that we catch any cases of incorrect datatypes being + passed to this method. + """ + try: + return data.get('id', None) + except AttributeError: + return None @property def errors(self): @@ -352,10 +431,37 @@ class BaseSerializer(Field): if many: ret = [] errors = [] - for item in data: - ret.append(self.from_native(item, None)) - errors.append(self._errors) - self._errors = any(errors) and errors or [] + update = self.object is not None + + if update: + # If this is a bulk update we need to map all the objects + # to a canonical identity so we can determine which + # individual object is being updated for each item in the + # incoming data + objects = self.object + identities = [self.get_identity(self.to_native(obj)) for obj in objects] + identity_to_objects = dict(zip(identities, objects)) + + if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): + for item in data: + if update: + # Determine which object we're updating + identity = self.get_identity(item) + self.object = identity_to_objects.pop(identity, None) + if self.object is None and not self.allow_add_remove: + ret.append(None) + errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) + continue + + ret.append(self.from_native(item, None)) + errors.append(self._errors) + + if update: + self._deleted = identity_to_objects.values() + + self._errors = any(errors) and errors or [] + else: + self._errors = {'non_field_errors': ['Expected a list of items.']} else: ret = self.from_native(data, files) @@ -394,6 +500,9 @@ class BaseSerializer(Field): def save_object(self, obj, **kwargs): obj.save(**kwargs) + def delete_object(self, obj): + obj.delete() + def save(self, **kwargs): """ Save the deserialized object and return it. @@ -402,6 +511,10 @@ class BaseSerializer(Field): [self.save_object(item, **kwargs) for item in self.object] else: self.save_object(self.object, **kwargs) + + if self.allow_add_remove and self._deleted: + [self.delete_object(item) for item in self._deleted] + return self.object @@ -584,33 +697,43 @@ class ModelSerializer(Serializer): """ Restore the model instance. """ - self.m2m_data = {} - self.related_data = {} + m2m_data = {} + related_data = {} + meta = self.opts.model._meta - # Reverse fk relations - for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): + # Reverse fk or one-to-one relations + for (obj, model) in meta.get_all_related_objects_with_model(): field_name = obj.field.related_query_name() if field_name in attrs: - self.related_data[field_name] = attrs.pop(field_name) + related_data[field_name] = attrs.pop(field_name) # Reverse m2m relations - for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): + for (obj, model) in meta.get_all_related_m2m_objects_with_model(): field_name = obj.field.related_query_name() if field_name in attrs: - self.m2m_data[field_name] = attrs.pop(field_name) + m2m_data[field_name] = attrs.pop(field_name) # Forward m2m relations - for field in self.opts.model._meta.many_to_many: + for field in meta.many_to_many: if field.name in attrs: - self.m2m_data[field.name] = attrs.pop(field.name) + m2m_data[field.name] = attrs.pop(field.name) + # Update an existing instance... if instance is not None: for key, val in attrs.items(): setattr(instance, key, val) + # ...or create a new instance else: instance = self.opts.model(**attrs) + # Any relations that cannot be set until we've + # saved the model get hidden away on these + # private attributes, so we can deal with them + # at the point of save. + instance._related_data = related_data + instance._m2m_data = m2m_data + return instance def from_native(self, data, files): @@ -627,15 +750,15 @@ class ModelSerializer(Serializer): """ obj.save(**kwargs) - if getattr(self, 'm2m_data', None): - for accessor_name, object_list in self.m2m_data.items(): - setattr(self.object, accessor_name, object_list) - self.m2m_data = {} + if getattr(obj, '_m2m_data', None): + for accessor_name, object_list in obj._m2m_data.items(): + setattr(obj, accessor_name, object_list) + del(obj._m2m_data) - if getattr(self, 'related_data', None): - for accessor_name, object_list in self.related_data.items(): - setattr(self.object, accessor_name, object_list) - self.related_data = {} + if getattr(obj, '_related_data', None): + for accessor_name, related in obj._related_data.items(): + setattr(obj, accessor_name, related) + del(obj._related_data) class HyperlinkedModelSerializerOptions(ModelSerializerOptions): @@ -690,3 +813,13 @@ class HyperlinkedModelSerializer(ModelSerializer): 'many': to_many } return HyperlinkedRelatedField(**kwargs) + + def get_identity(self, data): + """ + This hook is required for bulk update. + We need to override the default, to use the url as the identity. + """ + try: + return data.get('url', None) + except AttributeError: + return None diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c21ddcd7..b6ab2de3 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch from django.http import QueryDict from django.utils.html import escape from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import urlparse -from rest_framework.compat import force_text -from rest_framework.compat import six -import re -import string +from rest_framework.compat import urlparse, force_text, six, smart_urlquote +import re, string register = template.Library() @@ -112,22 +109,6 @@ def replace_query_param(url, key, val): class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') -# Bunch of stuff cloned from urlize -LEADING_PUNCTUATION = ['(', '<', '<', '"', "'"] -TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '>', '"', "'"] -DOTS = ['·', '*', '\xe2\x80\xa2', '•', '•', '•'] -unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') -word_split_re = re.compile(r'(\s+)') -punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \ - ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), - '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) -simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') -link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+') -html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) -hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) -trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') - - # And the template tags themselves... @register.simple_tag @@ -195,15 +176,25 @@ def add_class(value, css_class): return value +# Bunch of stuff cloned from urlize +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] +WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), + ('"', '"'), ("'", "'")] +word_split_re = re.compile(r'(\s+)') +simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE) +simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) +simple_email_re = re.compile(r'^\S+@\S+\.\S+$') + + @register.filter def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True): """ Converts any URLs in text into clickable links. - Works on http://, https://, www. links and links ending in .org, .net or - .com. Links can have trailing punctuation (periods, commas, close-parens) - and leading punctuation (opening parens) and it'll still do the right - thing. + Works on http://, https://, www. links, and also on links ending in one of + the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org). + Links can have trailing punctuation (periods, commas, close-parens) and + leading punctuation (opening parens) and it'll still do the right thing. If trim_url_limit is not None, the URLs in link text longer than this limit will truncated to trim_url_limit-3 characters and appended with an elipsis. @@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x safe_input = isinstance(text, SafeData) words = word_split_re.split(force_text(text)) - nofollow_attr = nofollow and ' rel="nofollow"' or '' for i, word in enumerate(words): match = None if '.' in word or '@' in word or ':' in word: - match = punctuation_re.match(word) - if match: - lead, middle, trail = match.groups() + # Deal with punctuation. + lead, middle, trail = '', word, '' + for punctuation in TRAILING_PUNCTUATION: + if middle.endswith(punctuation): + middle = middle[:-len(punctuation)] + trail = punctuation + trail + for opening, closing in WRAPPING_PUNCTUATION: + if middle.startswith(opening): + middle = middle[len(opening):] + lead = lead + opening + # Keep parentheses at the end only if they're balanced. + if (middle.endswith(closing) + and middle.count(closing) == middle.count(opening) + 1): + middle = middle[:-len(closing)] + trail = closing + trail + # Make URL we want to point to. url = None - if middle.startswith('http://') or middle.startswith('https://'): - url = middle - elif middle.startswith('www.') or ('@' not in middle and \ - middle and middle[0] in string.ascii_letters + string.digits and \ - (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): - url = 'http://%s' % middle - elif '@' in middle and not ':' in middle and simple_email_re.match(middle): - url = 'mailto:%s' % middle + nofollow_attr = ' rel="nofollow"' if nofollow else '' + if simple_url_re.match(middle): + url = smart_urlquote(middle) + elif simple_url_2_re.match(middle): + url = smart_urlquote('http://%s' % middle) + elif not ':' in middle and simple_email_re.match(middle): + local, domain = middle.rsplit('@', 1) + try: + domain = domain.encode('idna').decode('ascii') + except UnicodeError: + continue + url = 'mailto:%s@%s' % (local, domain) nofollow_attr = '' + # Make link. if url: trimmed = trim_url(middle) @@ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru words[i] = mark_safe(word) elif autoescape: words[i] = escape(word) - return mark_safe(''.join(words)) + return ''.join(words) diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index b663ca48..8e6d3e51 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -466,17 +466,13 @@ class OAuth2Tests(TestCase): def _create_authorization_header(self, token=None): return "Bearer {0}".format(token or self.access_token.token) - def _client_credentials_params(self): - return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_with_wrong_authorization_header_token_type_failing(self): """Ensure that a wrong token type lead to the correct HTTP error status code""" auth = "Wrong token-type-obsviously" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -485,8 +481,7 @@ class OAuth2Tests(TestCase): auth = "Bearer wrong token format" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -495,33 +490,21 @@ class OAuth2Tests(TestCase): auth = "Bearer wrong-token" response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) - self.assertEqual(response.status_code, 401) - - @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') - def test_get_form_with_wrong_client_data_failing_auth(self): - """Ensure GETing form over OAuth with incorrect client credentials fails""" - auth = self._create_authorization_header() - params = self._client_credentials_params() - params['client_id'] += 'a' - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_get_form_passing_auth(self): """Ensure GETing form over OAuth with correct client credentials succeed""" auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_passing_auth(self): """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -529,16 +512,14 @@ class OAuth2Tests(TestCase): """Ensure POSTing when there is no OAuth access token in db fails""" self.access_token.delete() auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_with_refresh_token_failing_auth(self): """Ensure POSTing with refresh token instead of access token fails""" auth = self._create_authorization_header(token=self.refresh_token.token) - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -547,8 +528,7 @@ class OAuth2Tests(TestCase): self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late self.access_token.save() auth = self._create_authorization_header() - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) self.assertIn('Invalid token', response.content) @@ -559,10 +539,9 @@ class OAuth2Tests(TestCase): read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] read_only_access_token.save() auth = self._create_authorization_header(token=read_only_access_token.token) - params = self._client_credentials_params() - response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) - response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -572,6 +551,5 @@ class OAuth2Tests(TestCase): read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] read_write_access_token.save() auth = self._create_authorization_header(token=read_write_access_token.token) - params = self._client_credentials_params() - response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index fd6de779..19c663d8 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -153,12 +153,22 @@ class DateFieldTest(TestCase): def test_to_native(self): """ - Make sure to_native() returns isoformat as default. + Make sure to_native() returns datetime as default. """ f = serializers.DateField() result_1 = f.to_native(datetime.date(1984, 7, 31)) + self.assertEqual(datetime.date(1984, 7, 31), result_1) + + def test_to_native_iso(self): + """ + Make sure to_native() with 'iso-8601' returns iso formated date. + """ + f = serializers.DateField(format='iso-8601') + + result_1 = f.to_native(datetime.date(1984, 7, 31)) + self.assertEqual('1984-07-31', result_1) def test_to_native_custom_format(self): @@ -289,6 +299,22 @@ class DateTimeFieldTest(TestCase): result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + self.assertEqual(datetime.datetime(1984, 7, 31), result_1) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) + self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) + + def test_to_native_iso(self): + """ + Make sure to_native() with format=iso-8601 returns iso formatted datetime. + """ + f = serializers.DateTimeField(format='iso-8601') + + result_1 = f.to_native(datetime.datetime(1984, 7, 31)) + result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) + result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) + result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) + self.assertEqual('1984-07-31T00:00:00', result_1) self.assertEqual('1984-07-31T04:31:00', result_2) self.assertEqual('1984-07-31T04:31:59', result_3) @@ -419,13 +445,26 @@ class TimeFieldTest(TestCase): def test_to_native(self): """ - Make sure to_native() returns isoformat as default. + Make sure to_native() returns time object as default. """ f = serializers.TimeField() result_1 = f.to_native(datetime.time(4, 31)) result_2 = f.to_native(datetime.time(4, 31, 59)) result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + self.assertEqual(datetime.time(4, 31), result_1) + self.assertEqual(datetime.time(4, 31, 59), result_2) + self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + + def test_to_native_iso(self): + """ + Make sure to_native() with format='iso-8601' returns iso formatted time. + """ + f = serializers.TimeField(format='iso-8601') + result_1 = f.to_native(datetime.time(4, 31)) + result_2 = f.to_native(datetime.time(4, 31, 59)) + result_3 = f.to_native(datetime.time(4, 31, 59, 200)) + self.assertEqual('04:31:00', result_1) self.assertEqual('04:31:59', result_2) self.assertEqual('04:31:59.000200', result_3) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index fe92e0bc..1a71558c 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -1,11 +1,12 @@ from __future__ import unicode_literals import datetime from decimal import Decimal +from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, filters -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, patterns, url from rest_framework.tests.models import FilterableItem, BasicModel factory = RequestFactory() @@ -46,12 +47,21 @@ if django_filters: filter_class = MisconfiguredFilter filter_backend = filters.DjangoFilterBackend + class FilterClassDetailView(generics.RetrieveAPIView): + model = FilterableItem + filter_class = SeveralFieldsFilter + filter_backend = filters.DjangoFilterBackend + + urlpatterns = patterns('', + url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), + url(r'^$', FilterClassRootView.as_view(), name='root-view'), + ) -class IntegrationTestFiltering(TestCase): - """ - Integration tests for filtered list views. - """ +class CommonFilteringTestCase(TestCase): + def _serialize_object(self, obj): + return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} + def setUp(self): """ Create 10 FilterableItem instances. @@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} + self._serialize_object(obj) for obj in self.objects.all() ] + +class IntegrationTestFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered list views. + """ + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_fields_root_view(self): """ @@ -95,7 +111,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date] + expected_data = [f for f in self.data if f['date'] == search_date] self.assertEqual(response.data, expected_data) @unittest.skipUnless(django_filters, 'django-filters not installed') @@ -125,7 +141,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date] + expected_data = [f for f in self.data if f['date'] > search_date] self.assertEqual(response.data, expected_data) # Tests that the text filter set with 'icontains' in the filter class works. @@ -142,8 +158,7 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) - expected_data = [f for f in self.data if - datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and + expected_data = [f for f in self.data if f['date'] > search_date and f['decimal'] < search_decimal] self.assertEqual(response.data, expected_data) @@ -168,3 +183,50 @@ class IntegrationTestFiltering(TestCase): request = factory.get('/?integer=%s' % search_integer) response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): + """ + Integration tests for filtered detail views. + """ + urls = 'rest_framework.tests.filterset' + + def _get_url(self, item): + return reverse('detail-view', kwargs=dict(pk=item.pk)) + + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_get_filtered_detail_view(self): + """ + GET requests to filtered RetrieveAPIView that have a filter_class set + should return filtered results. + """ + item = self.objects.all()[0] + data = self._serialize_object(item) + + # Basic test with no filter. + response = self.client.get(self._get_url(item)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, data) + + # Tests that the decimal filter set that should fail. + search_decimal = Decimal('4.25') + high_item = self.objects.filter(decimal__gt=search_decimal)[0] + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + # Tests that the decimal filter set that should succeed. + search_decimal = Decimal('4.25') + low_item = self.objects.filter(decimal__lt=search_decimal)[0] + low_item_data = self._serialize_object(low_item) + response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, low_item_data) + + # Tests that multiple filters works. + search_decimal = Decimal('5.25') + search_date = datetime.date(2012, 10, 2) + valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] + valid_item_data = self._serialize_object(valid_item) + response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, valid_item_data) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 1a2d68a6..6b8ef02f 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -102,7 +102,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} for obj in self.objects.all() ] @@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase): view = FilterFieldsRootView.as_view() EXPECTED_NUM_QUERIES = 2 - if django.VERSION < (1, 4): - # On Django 1.3 we need to use django-filter 0.5.4 - # - # The filter objects there don't expose a `.count()` method, - # which means we only make a single query *but* it's a single - # query across *all* of the queryset, instead of a COUNT and then - # a SELECT with a LIMIT. - # - # Although this is fewer queries, it's actually a regression. - EXPECTED_NUM_QUERIES = 1 request = factory.get('/?decimal=15.20') with self.assertNumQueries(EXPECTED_NUM_QUERIES): diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index beb372c2..05217f35 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -112,7 +112,7 @@ class BasicTests(TestCase): self.expected = { 'email': 'tom@example.com', 'content': 'Happy new year!', - 'created': '2012-01-01T00:00:00', + 'created': datetime.datetime(2012, 1, 1), 'sub_comment': 'And Merry Christmas!' } self.person_data = {'name': 'dwight', 'age': 35} @@ -261,34 +261,6 @@ class ValidationTests(TestCase): self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.errors, {}) - def test_bad_type_data_is_false(self): - """ - Data of the wrong type is not valid. - """ - data = ['i am', 'a', 'list'] - serializer = CommentSerializer(self.comment, data=data, many=True) - self.assertEqual(serializer.is_valid(), False) - self.assertTrue(isinstance(serializer.errors, list)) - - self.assertEqual( - serializer.errors, - [ - {'non_field_errors': ['Invalid data']}, - {'non_field_errors': ['Invalid data']}, - {'non_field_errors': ['Invalid data']} - ] - ) - - data = 'and i am a string' - serializer = CommentSerializer(self.comment, data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) - - data = 42 - serializer = CommentSerializer(self.comment, data=data) - self.assertEqual(serializer.is_valid(), False) - self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) - def test_cross_field_validation(self): class CommentSerializerWithCrossFieldValidator(CommentSerializer): diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py new file mode 100644 index 00000000..8b0ded1a --- /dev/null +++ b/rest_framework/tests/serializer_bulk_update.py @@ -0,0 +1,278 @@ +""" +Tests to cover bulk create and update using serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class BulkCreateSerializerTests(TestCase): + """ + Creating multiple instances using serializers. + """ + + def setUp(self): + class BookSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=100) + author = serializers.CharField(max_length=100) + + self.BookSerializer = BookSerializer + + def test_bulk_create_success(self): + """ + Correct bulk update serialization should return the input data. + """ + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 2, + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + + def test_bulk_create_errors(self): + """ + Correct bulk update serialization should return the input data. + """ + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 'foo', + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {}, + {'id': ['Enter a whole number.']} + ] + + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_list_datatype(self): + """ + Data containing list of incorrect data type should return errors. + """ + data = ['foo', 'bar', 'baz'] + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = [ + {'non_field_errors': ['Invalid data']}, + {'non_field_errors': ['Invalid data']}, + {'non_field_errors': ['Invalid data']} + ] + + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_single_datatype(self): + """ + Data containing a single incorrect data type should return errors. + """ + data = 123 + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = {'non_field_errors': ['Expected a list of items.']} + + self.assertEqual(serializer.errors, expected_errors) + + def test_invalid_single_object(self): + """ + Data containing only a single object, instead of a list of objects + should return errors. + """ + data = { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + } + serializer = self.BookSerializer(data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + + expected_errors = {'non_field_errors': ['Expected a list of items.']} + + self.assertEqual(serializer.errors, expected_errors) + + +class BulkUpdateSerializerTests(TestCase): + """ + Updating multiple instances using serializers. + """ + + def setUp(self): + class Book(object): + """ + A data type that can be persisted to a mock storage backend + with `.save()` and `.delete()`. + """ + object_map = {} + + def __init__(self, id, title, author): + self.id = id + self.title = title + self.author = author + + def save(self): + Book.object_map[self.id] = self + + def delete(self): + del Book.object_map[self.id] + + class BookSerializer(serializers.Serializer): + id = serializers.IntegerField() + title = serializers.CharField(max_length=100) + author = serializers.CharField(max_length=100) + + def restore_object(self, attrs, instance=None): + if instance: + instance.id = attrs['id'] + instance.title = attrs['title'] + instance.author = attrs['author'] + return instance + return Book(**attrs) + + self.Book = Book + self.BookSerializer = BookSerializer + + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 1, + 'title': 'If this is a man', + 'author': 'Primo Levi' + }, { + 'id': 2, + 'title': 'The wind-up bird chronicle', + 'author': 'Haruki Murakami' + } + ] + + for item in data: + book = Book(item['id'], item['title'], item['author']) + book.save() + + def books(self): + """ + Return all the objects in the mock storage backend. + """ + return self.Book.object_map.values() + + def test_bulk_update_success(self): + """ + Correct bulk update serialization should return the input data. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 2, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + serializer.save() + new_data = self.BookSerializer(self.books(), many=True).data + + self.assertEqual(data, new_data) + + def test_bulk_update_and_create(self): + """ + Bulk update serialization may also include created items. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 3, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.data, data) + serializer.save() + new_data = self.BookSerializer(self.books(), many=True).data + self.assertEqual(data, new_data) + + def test_bulk_update_invalid_create(self): + """ + Bulk update serialization without allow_add_remove may not create items. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 3, + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} + ] + serializer = self.BookSerializer(self.books(), data=data, many=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_bulk_update_error(self): + """ + Incorrect bulk update serialization should return error data. + """ + data = [ + { + 'id': 0, + 'title': 'The electric kool-aid acid test', + 'author': 'Tom Wolfe' + }, { + 'id': 'foo', + 'title': 'Kafka on the shore', + 'author': 'Haruki Murakami' + } + ] + expected_errors = [ + {}, + {'id': ['Enter a whole number.']} + ] + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py new file mode 100644 index 00000000..6a29c652 --- /dev/null +++ b/rest_framework/tests/serializer_nested.py @@ -0,0 +1,246 @@ +""" +Tests to cover nested serializers. + +Doesn't cover model serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class WritableNestedSerializerBasicTests(TestCase): + """ + Tests for deserializing nested entities. + Basic tests that use serializers that simply restore to dicts. + """ + + def setUp(self): + class TrackSerializer(serializers.Serializer): + order = serializers.IntegerField() + title = serializers.CharField(max_length=100) + duration = serializers.IntegerField() + + class AlbumSerializer(serializers.Serializer): + album_name = serializers.CharField(max_length=100) + artist = serializers.CharField(max_length=100) + tracks = TrackSerializer(many=True) + + self.AlbumSerializer = AlbumSerializer + + def test_nested_validation_success(self): + """ + Correct nested serialization should return the input data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, data) + + def test_nested_validation_error(self): + """ + Incorrect nested serialization should return appropriate error data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} + ] + } + expected_errors = { + 'tracks': [ + {}, + {}, + {'duration': ['Enter a whole number.']} + ] + } + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + def test_many_nested_validation_error(self): + """ + Incorrect nested serialization should return appropriate error data + when multiple entities are being deserialized. + """ + + data = [ + { + 'album_name': 'Russian Red', + 'artist': 'I Love Your Glasses', + 'tracks': [ + {'order': 1, 'title': 'Cigarettes', 'duration': 121}, + {'order': 2, 'title': 'No Past Land', 'duration': 198}, + {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} + ] + }, + { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} + ] + } + ] + expected_errors = [ + {}, + { + 'tracks': [ + {}, + {}, + {'duration': ['Enter a whole number.']} + ] + } + ] + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), False) + self.assertEqual(serializer.errors, expected_errors) + + +class WritableNestedSerializerObjectTests(TestCase): + """ + Tests for deserializing nested entities. + These tests use serializers that restore to concrete objects. + """ + + def setUp(self): + # Couple of concrete objects that we're going to deserialize into + class Track(object): + def __init__(self, order, title, duration): + self.order, self.title, self.duration = order, title, duration + + def __eq__(self, other): + return ( + self.order == other.order and + self.title == other.title and + self.duration == other.duration + ) + + class Album(object): + def __init__(self, album_name, artist, tracks): + self.album_name, self.artist, self.tracks = album_name, artist, tracks + + def __eq__(self, other): + return ( + self.album_name == other.album_name and + self.artist == other.artist and + self.tracks == other.tracks + ) + + # And their corresponding serializers + class TrackSerializer(serializers.Serializer): + order = serializers.IntegerField() + title = serializers.CharField(max_length=100) + duration = serializers.IntegerField() + + def restore_object(self, attrs, instance=None): + return Track(attrs['order'], attrs['title'], attrs['duration']) + + class AlbumSerializer(serializers.Serializer): + album_name = serializers.CharField(max_length=100) + artist = serializers.CharField(max_length=100) + tracks = TrackSerializer(many=True) + + def restore_object(self, attrs, instance=None): + return Album(attrs['album_name'], attrs['artist'], attrs['tracks']) + + self.Album, self.Track = Album, Track + self.AlbumSerializer = AlbumSerializer + + def test_nested_validation_success(self): + """ + Correct nested serialization should return a restored object + that corresponds to the input data. + """ + + data = { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + expected_object = self.Album( + album_name='Discovery', + artist='Daft Punk', + tracks=[ + self.Track(order=1, title='One More Time', duration=235), + self.Track(order=2, title='Aerodynamic', duration=184), + self.Track(order=3, title='Digital Love', duration=239), + ] + ) + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected_object) + + def test_many_nested_validation_success(self): + """ + Correct nested serialization should return multiple restored objects + that corresponds to the input data when multiple objects are + being deserialized. + """ + + data = [ + { + 'album_name': 'Russian Red', + 'artist': 'I Love Your Glasses', + 'tracks': [ + {'order': 1, 'title': 'Cigarettes', 'duration': 121}, + {'order': 2, 'title': 'No Past Land', 'duration': 198}, + {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} + ] + }, + { + 'album_name': 'Discovery', + 'artist': 'Daft Punk', + 'tracks': [ + {'order': 1, 'title': 'One More Time', 'duration': 235}, + {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, + {'order': 3, 'title': 'Digital Love', 'duration': 239} + ] + } + ] + expected_object = [ + self.Album( + album_name='Russian Red', + artist='I Love Your Glasses', + tracks=[ + self.Track(order=1, title='Cigarettes', duration=121), + self.Track(order=2, title='No Past Land', duration=198), + self.Track(order=3, title='They Don\'t Believe', duration=191), + ] + ), + self.Album( + album_name='Discovery', + artist='Daft Punk', + tracks=[ + self.Track(order=1, title='One More Time', duration=235), + self.Track(order=2, title='Aerodynamic', duration=184), + self.Track(order=3, title='Digital Love', duration=239), + ] + ) + ] + + serializer = self.AlbumSerializer(data=data) + self.assertEqual(serializer.is_valid(), True) + self.assertEqual(serializer.object, expected_object) diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py deleted file mode 100644 index e1644a6b..00000000 --- a/rest_framework/tests/status.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Tests for the status module""" -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework import status - - -class TestStatus(TestCase): - """Simple sanity test to check the status module""" - - def test_status(self): - """Ensure the status module is present and correct.""" - self.assertEqual(200, status.HTTP_200_OK) - self.assertEqual(404, status.HTTP_404_NOT_FOUND) |
