aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-03-30 15:41:38 +0000
committerTom Christie2013-03-30 15:41:38 +0000
commitb4945f476c5e18be60429441abc0671bf7b193ec (patch)
treea16524b93f1ec4e775c380005cb2b8cf33e90054 /rest_framework
parent922ee61d8611b41e2944b6503af736b1790abe83 (diff)
parent399ac70b831d782b7d774950b59f3b2066ab86f7 (diff)
downloaddjango-rest-framework-b4945f476c5e18be60429441abc0671bf7b193ec.tar.bz2
Merge branch 'master' into resources-routers
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authentication.py29
-rw-r--r--rest_framework/compat.py33
-rw-r--r--rest_framework/fields.py23
-rw-r--r--rest_framework/filters.py2
-rw-r--r--rest_framework/serializers.py217
-rw-r--r--rest_framework/templatetags/rest_framework.py84
-rw-r--r--rest_framework/tests/authentication.py44
-rw-r--r--rest_framework/tests/fields.py43
-rw-r--r--rest_framework/tests/filterset.py82
-rw-r--r--rest_framework/tests/pagination.py12
-rw-r--r--rest_framework/tests/serializer.py30
-rw-r--r--rest_framework/tests/serializer_bulk_update.py278
-rw-r--r--rest_framework/tests/serializer_nested.py246
-rw-r--r--rest_framework/tests/status.py13
15 files changed, 931 insertions, 207 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index cf005636..c86403d8 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '2.2.4'
+__version__ = '2.2.5'
VERSION = __version__ # synonym
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index b4b73699..145d4295 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -2,14 +2,16 @@
Provides a set of pluggable authentication policies.
"""
from __future__ import unicode_literals
+import base64
+from datetime import datetime
+
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
-from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends
+from rest_framework.compat import oauth2_provider, oauth2_provider_forms
from rest_framework.authtoken.models import Token
-import base64
def get_authorization_header(request):
@@ -204,6 +206,9 @@ class OAuthAuthentication(BaseAuthentication):
except oauth.Error as err:
raise exceptions.AuthenticationFailed(err.message)
+ if not oauth_request:
+ return None
+
oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES
found = any(param for param in oauth_params if param in oauth_request)
@@ -312,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication):
Authenticate the request, given the access token.
"""
- # Authenticate the client
- oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST)
- if not oauth2_client_form.is_valid():
- raise exceptions.AuthenticationFailed('Client could not be validated')
- client = oauth2_client_form.cleaned_data.get('client')
-
- # Retrieve the `OAuth2AccessToken` instance from the access_token
- auth_backend = oauth2_provider_backends.AccessTokenBackend()
- token = auth_backend.authenticate(access_token, client)
- if token is None:
+ try:
+ token = oauth2_provider.models.AccessToken.objects.select_related('user')
+ # TODO: Change to timezone aware datetime when oauth2_provider add
+ # support to it.
+ token = token.get(token=access_token, expires__gt=datetime.now())
+ except oauth2_provider.models.AccessToken.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
- user = token.user
-
- if not user.is_active:
+ if not token.user.is_active:
msg = 'User inactive or deleted: %s' % user.username
raise exceptions.AuthenticationFailed(msg)
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 7b2ef738..6551723a 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -395,6 +395,37 @@ except ImportError:
kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
return datetime.datetime(**kw)
+
+# smart_urlquote is new on Django 1.4
+try:
+ from django.utils.html import smart_urlquote
+except ImportError:
+ try:
+ from urllib.parse import quote, urlsplit, urlunsplit
+ except ImportError: # Python 2
+ from urllib import quote
+ from urlparse import urlsplit, urlunsplit
+
+ def smart_urlquote(url):
+ "Quotes a URL if it isn't already quoted."
+ # Handle IDN before quoting.
+ scheme, netloc, path, query, fragment = urlsplit(url)
+ try:
+ netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE
+ except UnicodeError: # invalid domain part
+ pass
+ else:
+ url = urlunsplit((scheme, netloc, path, query, fragment))
+
+ # An URL is considered unquoted if it contains no % characters or
+ # contains a % not followed by two hexadecimal digits. See #9655.
+ if '%' not in url or unquoted_percents_re.search(url):
+ # See http://bugs.python.org/issue2637
+ url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~')
+
+ return force_text(url)
+
+
# Markdown is optional
try:
import markdown
@@ -445,14 +476,12 @@ except ImportError:
# OAuth 2 support is optional
try:
import provider.oauth2 as oauth2_provider
- from provider.oauth2 import backends as oauth2_provider_backends
from provider.oauth2 import models as oauth2_provider_models
from provider.oauth2 import forms as oauth2_provider_forms
from provider import scope as oauth2_provider_scope
from provider import constants as oauth2_constants
except ImportError:
oauth2_provider = None
- oauth2_provider_backends = None
oauth2_provider_models = None
oauth2_provider_forms = None
oauth2_provider_scope = None
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 4b6931ad..f3496b53 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -494,7 +494,7 @@ class DateField(WritableField):
}
empty = None
input_formats = api_settings.DATE_INPUT_FORMATS
- format = api_settings.DATE_FORMAT
+ format = None
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -536,8 +536,8 @@ class DateField(WritableField):
raise ValidationError(msg)
def to_native(self, value):
- if value is None:
- return None
+ if value is None or self.format is None:
+ return value
if isinstance(value, datetime.datetime):
value = value.date()
@@ -557,7 +557,7 @@ class DateTimeField(WritableField):
}
empty = None
input_formats = api_settings.DATETIME_INPUT_FORMATS
- format = api_settings.DATETIME_FORMAT
+ format = None
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -605,11 +605,14 @@ class DateTimeField(WritableField):
raise ValidationError(msg)
def to_native(self, value):
- if value is None:
- return None
+ if value is None or self.format is None:
+ return value
if self.format.lower() == ISO_8601:
- return value.isoformat()
+ ret = value.isoformat()
+ if ret.endswith('+00:00'):
+ ret = ret[:-6] + 'Z'
+ return ret
return value.strftime(self.format)
@@ -623,7 +626,7 @@ class TimeField(WritableField):
}
empty = None
input_formats = api_settings.TIME_INPUT_FORMATS
- format = api_settings.TIME_FORMAT
+ format = None
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
@@ -658,8 +661,8 @@ class TimeField(WritableField):
raise ValidationError(msg)
def to_native(self, value):
- if value is None:
- return None
+ if value is None or self.format is None:
+ return value
if isinstance(value, datetime.datetime):
value = value.time()
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index 6fea46fa..413fa0d2 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
filter_class = self.get_filter_class(view)
if filter_class:
- return filter_class(request.QUERY_PARAMS, queryset=queryset)
+ return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
return queryset
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 4fe857a6..1b2b0821 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -20,6 +20,25 @@ from rest_framework.relations import *
from rest_framework.fields import *
+class NestedValidationError(ValidationError):
+ """
+ The default ValidationError behavior is to stringify each item in the list
+ if the messages are a list of error messages.
+
+ In the case of nested serializers, where the parent has many children,
+ then the child's `serializer.errors` will be a list of dicts. In the case
+ of a single child, the `serializer.errors` will be a dict.
+
+ We need to override the default behavior to get properly nested error dicts.
+ """
+
+ def __init__(self, message):
+ if isinstance(message, dict):
+ self.messages = [message]
+ else:
+ self.messages = message
+
+
class DictWithMetadata(dict):
"""
A dict-like object, that can have additional properties attached.
@@ -98,7 +117,7 @@ class SerializerOptions(object):
self.exclude = getattr(meta, 'exclude', ())
-class BaseSerializer(Field):
+class BaseSerializer(WritableField):
"""
This is the Serializer implementation.
We need to implement it as `BaseSerializer` due to metaclass magicks.
@@ -110,13 +129,15 @@ class BaseSerializer(Field):
_dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=None, source=None):
- super(BaseSerializer, self).__init__(source=source)
+ context=None, partial=False, many=None,
+ allow_add_remove=False, **kwargs):
+ super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
self.many = many
+ self.allow_add_remove = allow_add_remove
self.context = context or {}
@@ -128,6 +149,13 @@ class BaseSerializer(Field):
self._data = None
self._files = None
self._errors = None
+ self._deleted = None
+
+ if many and instance is not None and not hasattr(instance, '__iter__'):
+ raise ValueError('instance should be a queryset or other iterable with many=True')
+
+ if allow_add_remove and not many:
+ raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
#####
# Methods to determine which fields to use when (de)serializing objects.
@@ -296,40 +324,91 @@ class BaseSerializer(Field):
def field_to_native(self, obj, field_name):
"""
- Override default so that we can apply ModelSerializer as a nested
- field to relationships.
+ Override default so that the serializer can be used as a nested field
+ across relationships.
"""
if self.source == '*':
return self.to_native(obj)
try:
- if self.source:
- for component in self.source.split('.'):
- obj = getattr(obj, component)
- if is_simple_callable(obj):
- obj = obj()
- else:
- obj = getattr(obj, field_name)
- if is_simple_callable(obj):
- obj = obj()
+ source = self.source or field_name
+ value = obj
+
+ for component in source.split('.'):
+ value = get_component(value, component)
+ if value is None:
+ break
except ObjectDoesNotExist:
return None
- # If the object has an "all" method, assume it's a relationship
- if is_simple_callable(getattr(obj, 'all', None)):
- return [self.to_native(item) for item in obj.all()]
+ if is_simple_callable(getattr(value, 'all', None)):
+ return [self.to_native(item) for item in value.all()]
- if obj is None:
+ if value is None:
return None
if self.many is not None:
many = self.many
else:
- many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict, six.text_type))
+ many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type))
if many:
- return [self.to_native(item) for item in obj]
- return self.to_native(obj)
+ return [self.to_native(item) for item in value]
+ return self.to_native(value)
+
+ def field_from_native(self, data, files, field_name, into):
+ """
+ Override default so that the serializer can be used as a writable
+ nested field across relationships.
+ """
+ if self.read_only:
+ return
+
+ try:
+ value = data[field_name]
+ except KeyError:
+ if self.default is not None and not self.partial:
+ # Note: partial updates shouldn't set defaults
+ value = copy.deepcopy(self.default)
+ else:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
+ return
+
+ # Set the serializer object if it exists
+ obj = getattr(self.parent.object, field_name) if self.parent.object else None
+
+ if value in (None, ''):
+ into[(self.source or field_name)] = None
+ else:
+ kwargs = {
+ 'instance': obj,
+ 'data': value,
+ 'context': self.context,
+ 'partial': self.partial,
+ 'many': self.many
+ }
+ serializer = self.__class__(**kwargs)
+
+ if serializer.is_valid():
+ into[self.source or field_name] = serializer.object
+ else:
+ # Propagate errors up to our parent
+ raise NestedValidationError(serializer.errors)
+
+ def get_identity(self, data):
+ """
+ This hook is required for bulk update.
+ It is used to determine the canonical identity of a given object.
+
+ Note that the data has not been validated at this point, so we need
+ to make sure that we catch any cases of incorrect datatypes being
+ passed to this method.
+ """
+ try:
+ return data.get('id', None)
+ except AttributeError:
+ return None
@property
def errors(self):
@@ -352,10 +431,37 @@ class BaseSerializer(Field):
if many:
ret = []
errors = []
- for item in data:
- ret.append(self.from_native(item, None))
- errors.append(self._errors)
- self._errors = any(errors) and errors or []
+ update = self.object is not None
+
+ if update:
+ # If this is a bulk update we need to map all the objects
+ # to a canonical identity so we can determine which
+ # individual object is being updated for each item in the
+ # incoming data
+ objects = self.object
+ identities = [self.get_identity(self.to_native(obj)) for obj in objects]
+ identity_to_objects = dict(zip(identities, objects))
+
+ if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
+ for item in data:
+ if update:
+ # Determine which object we're updating
+ identity = self.get_identity(item)
+ self.object = identity_to_objects.pop(identity, None)
+ if self.object is None and not self.allow_add_remove:
+ ret.append(None)
+ errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
+ continue
+
+ ret.append(self.from_native(item, None))
+ errors.append(self._errors)
+
+ if update:
+ self._deleted = identity_to_objects.values()
+
+ self._errors = any(errors) and errors or []
+ else:
+ self._errors = {'non_field_errors': ['Expected a list of items.']}
else:
ret = self.from_native(data, files)
@@ -394,6 +500,9 @@ class BaseSerializer(Field):
def save_object(self, obj, **kwargs):
obj.save(**kwargs)
+ def delete_object(self, obj):
+ obj.delete()
+
def save(self, **kwargs):
"""
Save the deserialized object and return it.
@@ -402,6 +511,10 @@ class BaseSerializer(Field):
[self.save_object(item, **kwargs) for item in self.object]
else:
self.save_object(self.object, **kwargs)
+
+ if self.allow_add_remove and self._deleted:
+ [self.delete_object(item) for item in self._deleted]
+
return self.object
@@ -584,33 +697,43 @@ class ModelSerializer(Serializer):
"""
Restore the model instance.
"""
- self.m2m_data = {}
- self.related_data = {}
+ m2m_data = {}
+ related_data = {}
+ meta = self.opts.model._meta
- # Reverse fk relations
- for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model():
+ # Reverse fk or one-to-one relations
+ for (obj, model) in meta.get_all_related_objects_with_model():
field_name = obj.field.related_query_name()
if field_name in attrs:
- self.related_data[field_name] = attrs.pop(field_name)
+ related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations
- for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
+ for (obj, model) in meta.get_all_related_m2m_objects_with_model():
field_name = obj.field.related_query_name()
if field_name in attrs:
- self.m2m_data[field_name] = attrs.pop(field_name)
+ m2m_data[field_name] = attrs.pop(field_name)
# Forward m2m relations
- for field in self.opts.model._meta.many_to_many:
+ for field in meta.many_to_many:
if field.name in attrs:
- self.m2m_data[field.name] = attrs.pop(field.name)
+ m2m_data[field.name] = attrs.pop(field.name)
+ # Update an existing instance...
if instance is not None:
for key, val in attrs.items():
setattr(instance, key, val)
+ # ...or create a new instance
else:
instance = self.opts.model(**attrs)
+ # Any relations that cannot be set until we've
+ # saved the model get hidden away on these
+ # private attributes, so we can deal with them
+ # at the point of save.
+ instance._related_data = related_data
+ instance._m2m_data = m2m_data
+
return instance
def from_native(self, data, files):
@@ -627,15 +750,15 @@ class ModelSerializer(Serializer):
"""
obj.save(**kwargs)
- if getattr(self, 'm2m_data', None):
- for accessor_name, object_list in self.m2m_data.items():
- setattr(self.object, accessor_name, object_list)
- self.m2m_data = {}
+ if getattr(obj, '_m2m_data', None):
+ for accessor_name, object_list in obj._m2m_data.items():
+ setattr(obj, accessor_name, object_list)
+ del(obj._m2m_data)
- if getattr(self, 'related_data', None):
- for accessor_name, object_list in self.related_data.items():
- setattr(self.object, accessor_name, object_list)
- self.related_data = {}
+ if getattr(obj, '_related_data', None):
+ for accessor_name, related in obj._related_data.items():
+ setattr(obj, accessor_name, related)
+ del(obj._related_data)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
@@ -690,3 +813,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
'many': to_many
}
return HyperlinkedRelatedField(**kwargs)
+
+ def get_identity(self, data):
+ """
+ This hook is required for bulk update.
+ We need to override the default, to use the url as the identity.
+ """
+ try:
+ return data.get('url', None)
+ except AttributeError:
+ return None
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index c21ddcd7..b6ab2de3 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch
from django.http import QueryDict
from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe
-from rest_framework.compat import urlparse
-from rest_framework.compat import force_text
-from rest_framework.compat import six
-import re
-import string
+from rest_framework.compat import urlparse, force_text, six, smart_urlquote
+import re, string
register = template.Library()
@@ -112,22 +109,6 @@ def replace_query_param(url, key, val):
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
-# Bunch of stuff cloned from urlize
-LEADING_PUNCTUATION = ['(', '<', '&lt;', '"', "'"]
-TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '&gt;', '"', "'"]
-DOTS = ['&middot;', '*', '\xe2\x80\xa2', '&#149;', '&bull;', '&#8226;']
-unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)')
-word_split_re = re.compile(r'(\s+)')
-punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \
- ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]),
- '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION])))
-simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
-link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+')
-html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE)
-hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL)
-trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
-
-
# And the template tags themselves...
@register.simple_tag
@@ -195,15 +176,25 @@ def add_class(value, css_class):
return value
+# Bunch of stuff cloned from urlize
+TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"]
+WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'),
+ ('"', '"'), ("'", "'")]
+word_split_re = re.compile(r'(\s+)')
+simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE)
+simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
+simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
+
+
@register.filter
def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):
"""
Converts any URLs in text into clickable links.
- Works on http://, https://, www. links and links ending in .org, .net or
- .com. Links can have trailing punctuation (periods, commas, close-parens)
- and leading punctuation (opening parens) and it'll still do the right
- thing.
+ Works on http://, https://, www. links, and also on links ending in one of
+ the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org).
+ Links can have trailing punctuation (periods, commas, close-parens) and
+ leading punctuation (opening parens) and it'll still do the right thing.
If trim_url_limit is not None, the URLs in link text longer than this limit
will truncated to trim_url_limit-3 characters and appended with an elipsis.
@@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_text(text))
- nofollow_attr = nofollow and ' rel="nofollow"' or ''
for i, word in enumerate(words):
match = None
if '.' in word or '@' in word or ':' in word:
- match = punctuation_re.match(word)
- if match:
- lead, middle, trail = match.groups()
+ # Deal with punctuation.
+ lead, middle, trail = '', word, ''
+ for punctuation in TRAILING_PUNCTUATION:
+ if middle.endswith(punctuation):
+ middle = middle[:-len(punctuation)]
+ trail = punctuation + trail
+ for opening, closing in WRAPPING_PUNCTUATION:
+ if middle.startswith(opening):
+ middle = middle[len(opening):]
+ lead = lead + opening
+ # Keep parentheses at the end only if they're balanced.
+ if (middle.endswith(closing)
+ and middle.count(closing) == middle.count(opening) + 1):
+ middle = middle[:-len(closing)]
+ trail = closing + trail
+
# Make URL we want to point to.
url = None
- if middle.startswith('http://') or middle.startswith('https://'):
- url = middle
- elif middle.startswith('www.') or ('@' not in middle and \
- middle and middle[0] in string.ascii_letters + string.digits and \
- (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))):
- url = 'http://%s' % middle
- elif '@' in middle and not ':' in middle and simple_email_re.match(middle):
- url = 'mailto:%s' % middle
+ nofollow_attr = ' rel="nofollow"' if nofollow else ''
+ if simple_url_re.match(middle):
+ url = smart_urlquote(middle)
+ elif simple_url_2_re.match(middle):
+ url = smart_urlquote('http://%s' % middle)
+ elif not ':' in middle and simple_email_re.match(middle):
+ local, domain = middle.rsplit('@', 1)
+ try:
+ domain = domain.encode('idna').decode('ascii')
+ except UnicodeError:
+ continue
+ url = 'mailto:%s@%s' % (local, domain)
nofollow_attr = ''
+
# Make link.
if url:
trimmed = trim_url(middle)
@@ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words[i] = mark_safe(word)
elif autoescape:
words[i] = escape(word)
- return mark_safe(''.join(words))
+ return ''.join(words)
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index b663ca48..8e6d3e51 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -466,17 +466,13 @@ class OAuth2Tests(TestCase):
def _create_authorization_header(self, token=None):
return "Bearer {0}".format(token or self.access_token.token)
- def _client_credentials_params(self):
- return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
-
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_with_wrong_authorization_header_token_type_failing(self):
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
auth = "Wrong token-type-obsviously"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -485,8 +481,7 @@ class OAuth2Tests(TestCase):
auth = "Bearer wrong token format"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -495,33 +490,21 @@ class OAuth2Tests(TestCase):
auth = "Bearer wrong-token"
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
- self.assertEqual(response.status_code, 401)
-
- @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
- def test_get_form_with_wrong_client_data_failing_auth(self):
- """Ensure GETing form over OAuth with incorrect client credentials fails"""
- auth = self._create_authorization_header()
- params = self._client_credentials_params()
- params['client_id'] += 'a'
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth(self):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -529,16 +512,14 @@ class OAuth2Tests(TestCase):
"""Ensure POSTing when there is no OAuth access token in db fails"""
self.access_token.delete()
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_with_refresh_token_failing_auth(self):
"""Ensure POSTing with refresh token instead of access token fails"""
auth = self._create_authorization_header(token=self.refresh_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -547,8 +528,7 @@ class OAuth2Tests(TestCase):
self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
self.access_token.save()
auth = self._create_authorization_header()
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
self.assertIn('Invalid token', response.content)
@@ -559,10 +539,9 @@ class OAuth2Tests(TestCase):
read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
read_only_access_token.save()
auth = self._create_authorization_header(token=read_only_access_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
- response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
@@ -572,6 +551,5 @@ class OAuth2Tests(TestCase):
read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
read_write_access_token.save()
auth = self._create_authorization_header(token=read_write_access_token.token)
- params = self._client_credentials_params()
- response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py
index fd6de779..19c663d8 100644
--- a/rest_framework/tests/fields.py
+++ b/rest_framework/tests/fields.py
@@ -153,12 +153,22 @@ class DateFieldTest(TestCase):
def test_to_native(self):
"""
- Make sure to_native() returns isoformat as default.
+ Make sure to_native() returns datetime as default.
"""
f = serializers.DateField()
result_1 = f.to_native(datetime.date(1984, 7, 31))
+ self.assertEqual(datetime.date(1984, 7, 31), result_1)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with 'iso-8601' returns iso formated date.
+ """
+ f = serializers.DateField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.date(1984, 7, 31))
+
self.assertEqual('1984-07-31', result_1)
def test_to_native_custom_format(self):
@@ -289,6 +299,22 @@ class DateTimeFieldTest(TestCase):
result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+ self.assertEqual(datetime.datetime(1984, 7, 31), result_1)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3)
+ self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format=iso-8601 returns iso formatted datetime.
+ """
+ f = serializers.DateTimeField(format='iso-8601')
+
+ result_1 = f.to_native(datetime.datetime(1984, 7, 31))
+ result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31))
+ result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))
+ result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200))
+
self.assertEqual('1984-07-31T00:00:00', result_1)
self.assertEqual('1984-07-31T04:31:00', result_2)
self.assertEqual('1984-07-31T04:31:59', result_3)
@@ -419,13 +445,26 @@ class TimeFieldTest(TestCase):
def test_to_native(self):
"""
- Make sure to_native() returns isoformat as default.
+ Make sure to_native() returns time object as default.
"""
f = serializers.TimeField()
result_1 = f.to_native(datetime.time(4, 31))
result_2 = f.to_native(datetime.time(4, 31, 59))
result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+ self.assertEqual(datetime.time(4, 31), result_1)
+ self.assertEqual(datetime.time(4, 31, 59), result_2)
+ self.assertEqual(datetime.time(4, 31, 59, 200), result_3)
+
+ def test_to_native_iso(self):
+ """
+ Make sure to_native() with format='iso-8601' returns iso formatted time.
+ """
+ f = serializers.TimeField(format='iso-8601')
+ result_1 = f.to_native(datetime.time(4, 31))
+ result_2 = f.to_native(datetime.time(4, 31, 59))
+ result_3 = f.to_native(datetime.time(4, 31, 59, 200))
+
self.assertEqual('04:31:00', result_1)
self.assertEqual('04:31:59', result_2)
self.assertEqual('04:31:59.000200', result_3)
diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py
index fe92e0bc..1a71558c 100644
--- a/rest_framework/tests/filterset.py
+++ b/rest_framework/tests/filterset.py
@@ -1,11 +1,12 @@
from __future__ import unicode_literals
import datetime
from decimal import Decimal
+from django.core.urlresolvers import reverse
from django.test import TestCase
from django.test.client import RequestFactory
from django.utils import unittest
from rest_framework import generics, status, filters
-from rest_framework.compat import django_filters
+from rest_framework.compat import django_filters, patterns, url
from rest_framework.tests.models import FilterableItem, BasicModel
factory = RequestFactory()
@@ -46,12 +47,21 @@ if django_filters:
filter_class = MisconfiguredFilter
filter_backend = filters.DjangoFilterBackend
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ model = FilterableItem
+ filter_class = SeveralFieldsFilter
+ filter_backend = filters.DjangoFilterBackend
+
+ urlpatterns = patterns('',
+ url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ )
-class IntegrationTestFiltering(TestCase):
- """
- Integration tests for filtered list views.
- """
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
+
def setUp(self):
"""
Create 10 FilterableItem instances.
@@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ self._serialize_object(obj)
for obj in self.objects.all()
]
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
@unittest.skipUnless(django_filters, 'django-filters not installed')
def test_get_filtered_fields_root_view(self):
"""
@@ -95,7 +111,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date]
+ expected_data = [f for f in self.data if f['date'] == search_date]
self.assertEqual(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed')
@@ -125,7 +141,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date]
+ expected_data = [f for f in self.data if f['date'] > search_date]
self.assertEqual(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works.
@@ -142,8 +158,7 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
- expected_data = [f for f in self.data if
- datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and
+ expected_data = [f for f in self.data if f['date'] > search_date and
f['decimal'] < search_decimal]
self.assertEqual(response.data, expected_data)
@@ -168,3 +183,50 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?integer=%s' % search_integer)
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'rest_framework.tests.filterset'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filters not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 1a2d68a6..6b8ef02f 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -102,7 +102,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.objects = FilterableItem.objects
self.data = [
- {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()}
+ {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
for obj in self.objects.all()
]
@@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view = FilterFieldsRootView.as_view()
EXPECTED_NUM_QUERIES = 2
- if django.VERSION < (1, 4):
- # On Django 1.3 we need to use django-filter 0.5.4
- #
- # The filter objects there don't expose a `.count()` method,
- # which means we only make a single query *but* it's a single
- # query across *all* of the queryset, instead of a COUNT and then
- # a SELECT with a LIMIT.
- #
- # Although this is fewer queries, it's actually a regression.
- EXPECTED_NUM_QUERIES = 1
request = factory.get('/?decimal=15.20')
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index beb372c2..05217f35 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -112,7 +112,7 @@ class BasicTests(TestCase):
self.expected = {
'email': 'tom@example.com',
'content': 'Happy new year!',
- 'created': '2012-01-01T00:00:00',
+ 'created': datetime.datetime(2012, 1, 1),
'sub_comment': 'And Merry Christmas!'
}
self.person_data = {'name': 'dwight', 'age': 35}
@@ -261,34 +261,6 @@ class ValidationTests(TestCase):
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.errors, {})
- def test_bad_type_data_is_false(self):
- """
- Data of the wrong type is not valid.
- """
- data = ['i am', 'a', 'list']
- serializer = CommentSerializer(self.comment, data=data, many=True)
- self.assertEqual(serializer.is_valid(), False)
- self.assertTrue(isinstance(serializer.errors, list))
-
- self.assertEqual(
- serializer.errors,
- [
- {'non_field_errors': ['Invalid data']},
- {'non_field_errors': ['Invalid data']},
- {'non_field_errors': ['Invalid data']}
- ]
- )
-
- data = 'and i am a string'
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
- data = 42
- serializer = CommentSerializer(self.comment, data=data)
- self.assertEqual(serializer.is_valid(), False)
- self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']})
-
def test_cross_field_validation(self):
class CommentSerializerWithCrossFieldValidator(CommentSerializer):
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py
new file mode 100644
index 00000000..8b0ded1a
--- /dev/null
+++ b/rest_framework/tests/serializer_bulk_update.py
@@ -0,0 +1,278 @@
+"""
+Tests to cover bulk create and update using serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class BulkCreateSerializerTests(TestCase):
+ """
+ Creating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ self.BookSerializer = BookSerializer
+
+ def test_bulk_create_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_bulk_create_errors(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 'foo',
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_list_datatype(self):
+ """
+ Data containing list of incorrect data type should return errors.
+ """
+ data = ['foo', 'bar', 'baz']
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = [
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']},
+ {'non_field_errors': ['Invalid data']}
+ ]
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_datatype(self):
+ """
+ Data containing a single incorrect data type should return errors.
+ """
+ data = 123
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_object(self):
+ """
+ Data containing only a single object, instead of a list of objects
+ should return errors.
+ """
+ data = {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items.']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class BulkUpdateSerializerTests(TestCase):
+ """
+ Updating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class Book(object):
+ """
+ A data type that can be persisted to a mock storage backend
+ with `.save()` and `.delete()`.
+ """
+ object_map = {}
+
+ def __init__(self, id, title, author):
+ self.id = id
+ self.title = title
+ self.author = author
+
+ def save(self):
+ Book.object_map[self.id] = self
+
+ def delete(self):
+ del Book.object_map[self.id]
+
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ def restore_object(self, attrs, instance=None):
+ if instance:
+ instance.id = attrs['id']
+ instance.title = attrs['title']
+ instance.author = attrs['author']
+ return instance
+ return Book(**attrs)
+
+ self.Book = Book
+ self.BookSerializer = BookSerializer
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ for item in data:
+ book = Book(item['id'], item['title'], item['author'])
+ book.save()
+
+ def books(self):
+ """
+ Return all the objects in the mock storage backend.
+ """
+ return self.Book.object_map.values()
+
+ def test_bulk_update_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 2,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_and_create(self):
+ """
+ Bulk update serialization may also include created items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+ serializer.save()
+ new_data = self.BookSerializer(self.books(), many=True).data
+ self.assertEqual(data, new_data)
+
+ def test_bulk_update_invalid_create(self):
+ """
+ Bulk update serialization without allow_add_remove may not create items.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 3,
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_bulk_update_error(self):
+ """
+ Incorrect bulk update serialization should return error data.
+ """
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 'foo',
+ 'title': 'Kafka on the shore',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {'id': ['Enter a whole number.']}
+ ]
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py
new file mode 100644
index 00000000..6a29c652
--- /dev/null
+++ b/rest_framework/tests/serializer_nested.py
@@ -0,0 +1,246 @@
+"""
+Tests to cover nested serializers.
+
+Doesn't cover model serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class WritableNestedSerializerBasicTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ Basic tests that use serializers that simply restore to dicts.
+ """
+
+ def setUp(self):
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, data)
+
+ def test_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ expected_errors = {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_many_nested_validation_error(self):
+ """
+ Incorrect nested serialization should return appropriate error data
+ when multiple entities are being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'}
+ ]
+ }
+ ]
+ expected_errors = [
+ {},
+ {
+ 'tracks': [
+ {},
+ {},
+ {'duration': ['Enter a whole number.']}
+ ]
+ }
+ ]
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+
+class WritableNestedSerializerObjectTests(TestCase):
+ """
+ Tests for deserializing nested entities.
+ These tests use serializers that restore to concrete objects.
+ """
+
+ def setUp(self):
+ # Couple of concrete objects that we're going to deserialize into
+ class Track(object):
+ def __init__(self, order, title, duration):
+ self.order, self.title, self.duration = order, title, duration
+
+ def __eq__(self, other):
+ return (
+ self.order == other.order and
+ self.title == other.title and
+ self.duration == other.duration
+ )
+
+ class Album(object):
+ def __init__(self, album_name, artist, tracks):
+ self.album_name, self.artist, self.tracks = album_name, artist, tracks
+
+ def __eq__(self, other):
+ return (
+ self.album_name == other.album_name and
+ self.artist == other.artist and
+ self.tracks == other.tracks
+ )
+
+ # And their corresponding serializers
+ class TrackSerializer(serializers.Serializer):
+ order = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ duration = serializers.IntegerField()
+
+ def restore_object(self, attrs, instance=None):
+ return Track(attrs['order'], attrs['title'], attrs['duration'])
+
+ class AlbumSerializer(serializers.Serializer):
+ album_name = serializers.CharField(max_length=100)
+ artist = serializers.CharField(max_length=100)
+ tracks = TrackSerializer(many=True)
+
+ def restore_object(self, attrs, instance=None):
+ return Album(attrs['album_name'], attrs['artist'], attrs['tracks'])
+
+ self.Album, self.Track = Album, Track
+ self.AlbumSerializer = AlbumSerializer
+
+ def test_nested_validation_success(self):
+ """
+ Correct nested serialization should return a restored object
+ that corresponds to the input data.
+ """
+
+ data = {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ expected_object = self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
+
+ def test_many_nested_validation_success(self):
+ """
+ Correct nested serialization should return multiple restored objects
+ that corresponds to the input data when multiple objects are
+ being deserialized.
+ """
+
+ data = [
+ {
+ 'album_name': 'Russian Red',
+ 'artist': 'I Love Your Glasses',
+ 'tracks': [
+ {'order': 1, 'title': 'Cigarettes', 'duration': 121},
+ {'order': 2, 'title': 'No Past Land', 'duration': 198},
+ {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191}
+ ]
+ },
+ {
+ 'album_name': 'Discovery',
+ 'artist': 'Daft Punk',
+ 'tracks': [
+ {'order': 1, 'title': 'One More Time', 'duration': 235},
+ {'order': 2, 'title': 'Aerodynamic', 'duration': 184},
+ {'order': 3, 'title': 'Digital Love', 'duration': 239}
+ ]
+ }
+ ]
+ expected_object = [
+ self.Album(
+ album_name='Russian Red',
+ artist='I Love Your Glasses',
+ tracks=[
+ self.Track(order=1, title='Cigarettes', duration=121),
+ self.Track(order=2, title='No Past Land', duration=198),
+ self.Track(order=3, title='They Don\'t Believe', duration=191),
+ ]
+ ),
+ self.Album(
+ album_name='Discovery',
+ artist='Daft Punk',
+ tracks=[
+ self.Track(order=1, title='One More Time', duration=235),
+ self.Track(order=2, title='Aerodynamic', duration=184),
+ self.Track(order=3, title='Digital Love', duration=239),
+ ]
+ )
+ ]
+
+ serializer = self.AlbumSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.object, expected_object)
diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py
deleted file mode 100644
index e1644a6b..00000000
--- a/rest_framework/tests/status.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""Tests for the status module"""
-from __future__ import unicode_literals
-from django.test import TestCase
-from rest_framework import status
-
-
-class TestStatus(TestCase):
- """Simple sanity test to check the status module"""
-
- def test_status(self):
- """Ensure the status module is present and correct."""
- self.assertEqual(200, status.HTTP_200_OK)
- self.assertEqual(404, status.HTTP_404_NOT_FOUND)