diff options
| author | Tom Christie | 2013-03-30 15:41:38 +0000 | 
|---|---|---|
| committer | Tom Christie | 2013-03-30 15:41:38 +0000 | 
| commit | b4945f476c5e18be60429441abc0671bf7b193ec (patch) | |
| tree | a16524b93f1ec4e775c380005cb2b8cf33e90054 /rest_framework | |
| parent | 922ee61d8611b41e2944b6503af736b1790abe83 (diff) | |
| parent | 399ac70b831d782b7d774950b59f3b2066ab86f7 (diff) | |
| download | django-rest-framework-b4945f476c5e18be60429441abc0671bf7b193ec.tar.bz2 | |
Merge branch 'master' into resources-routers
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/authentication.py | 29 | ||||
| -rw-r--r-- | rest_framework/compat.py | 33 | ||||
| -rw-r--r-- | rest_framework/fields.py | 23 | ||||
| -rw-r--r-- | rest_framework/filters.py | 2 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 217 | ||||
| -rw-r--r-- | rest_framework/templatetags/rest_framework.py | 84 | ||||
| -rw-r--r-- | rest_framework/tests/authentication.py | 44 | ||||
| -rw-r--r-- | rest_framework/tests/fields.py | 43 | ||||
| -rw-r--r-- | rest_framework/tests/filterset.py | 82 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 12 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 30 | ||||
| -rw-r--r-- | rest_framework/tests/serializer_bulk_update.py | 278 | ||||
| -rw-r--r-- | rest_framework/tests/serializer_nested.py | 246 | ||||
| -rw-r--r-- | rest_framework/tests/status.py | 13 | 
15 files changed, 931 insertions, 207 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index cf005636..c86403d8 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.2.4' +__version__ = '2.2.5'  VERSION = __version__  # synonym diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index b4b73699..145d4295 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -2,14 +2,16 @@  Provides a set of pluggable authentication policies.  """  from __future__ import unicode_literals +import base64 +from datetime import datetime +  from django.contrib.auth import authenticate  from django.core.exceptions import ImproperlyConfigured  from rest_framework import exceptions, HTTP_HEADER_ENCODING  from rest_framework.compat import CsrfViewMiddleware  from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends +from rest_framework.compat import oauth2_provider, oauth2_provider_forms  from rest_framework.authtoken.models import Token -import base64  def get_authorization_header(request): @@ -204,6 +206,9 @@ class OAuthAuthentication(BaseAuthentication):          except oauth.Error as err:              raise exceptions.AuthenticationFailed(err.message) +        if not oauth_request: +            return None +          oauth_params = oauth_provider.consts.OAUTH_PARAMETERS_NAMES          found = any(param for param in oauth_params if param in oauth_request) @@ -312,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication):          Authenticate the request, given the access token.          """ -        # Authenticate the client -        oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST) -        if not oauth2_client_form.is_valid(): -            raise exceptions.AuthenticationFailed('Client could not be validated') -        client = oauth2_client_form.cleaned_data.get('client') - -        # Retrieve the `OAuth2AccessToken` instance from the access_token -        auth_backend = oauth2_provider_backends.AccessTokenBackend() -        token = auth_backend.authenticate(access_token, client) -        if token is None: +        try: +            token = oauth2_provider.models.AccessToken.objects.select_related('user') +            # TODO: Change to timezone aware datetime when oauth2_provider add +            # support to it. +            token = token.get(token=access_token, expires__gt=datetime.now()) +        except oauth2_provider.models.AccessToken.DoesNotExist:              raise exceptions.AuthenticationFailed('Invalid token') -        user = token.user - -        if not user.is_active: +        if not token.user.is_active:              msg = 'User inactive or deleted: %s' % user.username              raise exceptions.AuthenticationFailed(msg) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 7b2ef738..6551723a 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -395,6 +395,37 @@ except ImportError:              kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)              return datetime.datetime(**kw) + +# smart_urlquote is new on Django 1.4 +try: +    from django.utils.html import smart_urlquote +except ImportError: +    try: +        from urllib.parse import quote, urlsplit, urlunsplit +    except ImportError:     # Python 2 +        from urllib import quote +        from urlparse import urlsplit, urlunsplit + +    def smart_urlquote(url): +        "Quotes a URL if it isn't already quoted." +        # Handle IDN before quoting. +        scheme, netloc, path, query, fragment = urlsplit(url) +        try: +            netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE +        except UnicodeError: # invalid domain part +            pass +        else: +            url = urlunsplit((scheme, netloc, path, query, fragment)) + +        # An URL is considered unquoted if it contains no % characters or +        # contains a % not followed by two hexadecimal digits. See #9655. +        if '%' not in url or unquoted_percents_re.search(url): +            # See http://bugs.python.org/issue2637 +            url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~') + +        return force_text(url) + +  # Markdown is optional  try:      import markdown @@ -445,14 +476,12 @@ except ImportError:  # OAuth 2 support is optional  try:      import provider.oauth2 as oauth2_provider -    from provider.oauth2 import backends as oauth2_provider_backends      from provider.oauth2 import models as oauth2_provider_models      from provider.oauth2 import forms as oauth2_provider_forms      from provider import scope as oauth2_provider_scope      from provider import constants as oauth2_constants  except ImportError:      oauth2_provider = None -    oauth2_provider_backends = None      oauth2_provider_models = None      oauth2_provider_forms = None      oauth2_provider_scope = None diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4b6931ad..f3496b53 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -494,7 +494,7 @@ class DateField(WritableField):      }      empty = None      input_formats = api_settings.DATE_INPUT_FORMATS -    format = api_settings.DATE_FORMAT +    format = None      def __init__(self, input_formats=None, format=None, *args, **kwargs):          self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -536,8 +536,8 @@ class DateField(WritableField):          raise ValidationError(msg)      def to_native(self, value): -        if value is None: -            return None +        if value is None or self.format is None: +            return value          if isinstance(value, datetime.datetime):              value = value.date() @@ -557,7 +557,7 @@ class DateTimeField(WritableField):      }      empty = None      input_formats = api_settings.DATETIME_INPUT_FORMATS -    format = api_settings.DATETIME_FORMAT +    format = None      def __init__(self, input_formats=None, format=None, *args, **kwargs):          self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -605,11 +605,14 @@ class DateTimeField(WritableField):          raise ValidationError(msg)      def to_native(self, value): -        if value is None: -            return None +        if value is None or self.format is None: +            return value          if self.format.lower() == ISO_8601: -            return value.isoformat() +            ret = value.isoformat() +            if ret.endswith('+00:00'): +                ret = ret[:-6] + 'Z' +            return ret          return value.strftime(self.format) @@ -623,7 +626,7 @@ class TimeField(WritableField):      }      empty = None      input_formats = api_settings.TIME_INPUT_FORMATS -    format = api_settings.TIME_FORMAT +    format = None      def __init__(self, input_formats=None, format=None, *args, **kwargs):          self.input_formats = input_formats if input_formats is not None else self.input_formats @@ -658,8 +661,8 @@ class TimeField(WritableField):          raise ValidationError(msg)      def to_native(self, value): -        if value is None: -            return None +        if value is None or self.format is None: +            return value          if isinstance(value, datetime.datetime):              value = value.time() diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 6fea46fa..413fa0d2 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):          filter_class = self.get_filter_class(view)          if filter_class: -            return filter_class(request.QUERY_PARAMS, queryset=queryset) +            return filter_class(request.QUERY_PARAMS, queryset=queryset).qs          return queryset diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4fe857a6..1b2b0821 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,25 @@ from rest_framework.relations import *  from rest_framework.fields import * +class NestedValidationError(ValidationError): +    """ +    The default ValidationError behavior is to stringify each item in the list +    if the messages are a list of error messages. + +    In the case of nested serializers, where the parent has many children, +    then the child's `serializer.errors` will be a list of dicts.  In the case +    of a single child, the `serializer.errors` will be a dict. + +    We need to override the default behavior to get properly nested error dicts. +    """ + +    def __init__(self, message): +        if isinstance(message, dict): +            self.messages = [message] +        else: +            self.messages = message + +  class DictWithMetadata(dict):      """      A dict-like object, that can have additional properties attached. @@ -98,7 +117,7 @@ class SerializerOptions(object):          self.exclude = getattr(meta, 'exclude', ()) -class BaseSerializer(Field): +class BaseSerializer(WritableField):      """      This is the Serializer implementation.      We need to implement it as `BaseSerializer` due to metaclass magicks. @@ -110,13 +129,15 @@ class BaseSerializer(Field):      _dict_class = SortedDictWithMetadata      def __init__(self, instance=None, data=None, files=None, -                 context=None, partial=False, many=None, source=None): -        super(BaseSerializer, self).__init__(source=source) +                 context=None, partial=False, many=None, +                 allow_add_remove=False, **kwargs): +        super(BaseSerializer, self).__init__(**kwargs)          self.opts = self._options_class(self.Meta)          self.parent = None          self.root = None          self.partial = partial          self.many = many +        self.allow_add_remove = allow_add_remove          self.context = context or {} @@ -128,6 +149,13 @@ class BaseSerializer(Field):          self._data = None          self._files = None          self._errors = None +        self._deleted = None + +        if many and instance is not None and not hasattr(instance, '__iter__'): +            raise ValueError('instance should be a queryset or other iterable with many=True') + +        if allow_add_remove and not many: +            raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')      #####      # Methods to determine which fields to use when (de)serializing objects. @@ -296,40 +324,91 @@ class BaseSerializer(Field):      def field_to_native(self, obj, field_name):          """ -        Override default so that we can apply ModelSerializer as a nested -        field to relationships. +        Override default so that the serializer can be used as a nested field +        across relationships.          """          if self.source == '*':              return self.to_native(obj)          try: -            if self.source: -                for component in self.source.split('.'): -                    obj = getattr(obj, component) -                    if is_simple_callable(obj): -                        obj = obj() -            else: -                obj = getattr(obj, field_name) -                if is_simple_callable(obj): -                    obj = obj() +            source = self.source or field_name +            value = obj + +            for component in source.split('.'): +                value = get_component(value, component) +                if value is None: +                    break          except ObjectDoesNotExist:              return None -        # If the object has an "all" method, assume it's a relationship -        if is_simple_callable(getattr(obj, 'all', None)): -            return [self.to_native(item) for item in obj.all()] +        if is_simple_callable(getattr(value, 'all', None)): +            return [self.to_native(item) for item in value.all()] -        if obj is None: +        if value is None:              return None          if self.many is not None:              many = self.many          else: -            many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict, six.text_type)) +            many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type))          if many: -            return [self.to_native(item) for item in obj] -        return self.to_native(obj) +            return [self.to_native(item) for item in value] +        return self.to_native(value) + +    def field_from_native(self, data, files, field_name, into): +        """ +        Override default so that the serializer can be used as a writable +        nested field across relationships. +        """ +        if self.read_only: +            return + +        try: +            value = data[field_name] +        except KeyError: +            if self.default is not None and not self.partial: +                # Note: partial updates shouldn't set defaults +                value = copy.deepcopy(self.default) +            else: +                if self.required: +                    raise ValidationError(self.error_messages['required']) +                return + +        # Set the serializer object if it exists +        obj = getattr(self.parent.object, field_name) if self.parent.object else None + +        if value in (None, ''): +            into[(self.source or field_name)] = None +        else: +            kwargs = { +                'instance': obj, +                'data': value, +                'context': self.context, +                'partial': self.partial, +                'many': self.many +            } +            serializer = self.__class__(**kwargs) + +            if serializer.is_valid(): +                into[self.source or field_name] = serializer.object +            else: +                # Propagate errors up to our parent +                raise NestedValidationError(serializer.errors) + +    def get_identity(self, data): +        """ +        This hook is required for bulk update. +        It is used to determine the canonical identity of a given object. + +        Note that the data has not been validated at this point, so we need +        to make sure that we catch any cases of incorrect datatypes being +        passed to this method. +        """ +        try: +            return data.get('id', None) +        except AttributeError: +            return None      @property      def errors(self): @@ -352,10 +431,37 @@ class BaseSerializer(Field):              if many:                  ret = []                  errors = [] -                for item in data: -                    ret.append(self.from_native(item, None)) -                    errors.append(self._errors) -                self._errors = any(errors) and errors or [] +                update = self.object is not None + +                if update: +                    # If this is a bulk update we need to map all the objects +                    # to a canonical identity so we can determine which +                    # individual object is being updated for each item in the +                    # incoming data +                    objects = self.object +                    identities = [self.get_identity(self.to_native(obj)) for obj in objects] +                    identity_to_objects = dict(zip(identities, objects)) + +                if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)): +                    for item in data: +                        if update: +                            # Determine which object we're updating +                            identity = self.get_identity(item) +                            self.object = identity_to_objects.pop(identity, None) +                            if self.object is None and not self.allow_add_remove: +                                ret.append(None) +                                errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']}) +                                continue + +                        ret.append(self.from_native(item, None)) +                        errors.append(self._errors) + +                    if update: +                        self._deleted = identity_to_objects.values() + +                    self._errors = any(errors) and errors or [] +                else: +                    self._errors = {'non_field_errors': ['Expected a list of items.']}              else:                  ret = self.from_native(data, files) @@ -394,6 +500,9 @@ class BaseSerializer(Field):      def save_object(self, obj, **kwargs):          obj.save(**kwargs) +    def delete_object(self, obj): +        obj.delete() +      def save(self, **kwargs):          """          Save the deserialized object and return it. @@ -402,6 +511,10 @@ class BaseSerializer(Field):              [self.save_object(item, **kwargs) for item in self.object]          else:              self.save_object(self.object, **kwargs) + +        if self.allow_add_remove and self._deleted: +            [self.delete_object(item) for item in self._deleted] +          return self.object @@ -584,33 +697,43 @@ class ModelSerializer(Serializer):          """          Restore the model instance.          """ -        self.m2m_data = {} -        self.related_data = {} +        m2m_data = {} +        related_data = {} +        meta = self.opts.model._meta -        # Reverse fk relations -        for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model(): +        # Reverse fk or one-to-one relations +        for (obj, model) in meta.get_all_related_objects_with_model():              field_name = obj.field.related_query_name()              if field_name in attrs: -                self.related_data[field_name] = attrs.pop(field_name) +                related_data[field_name] = attrs.pop(field_name)          # Reverse m2m relations -        for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model(): +        for (obj, model) in meta.get_all_related_m2m_objects_with_model():              field_name = obj.field.related_query_name()              if field_name in attrs: -                self.m2m_data[field_name] = attrs.pop(field_name) +                m2m_data[field_name] = attrs.pop(field_name)          # Forward m2m relations -        for field in self.opts.model._meta.many_to_many: +        for field in meta.many_to_many:              if field.name in attrs: -                self.m2m_data[field.name] = attrs.pop(field.name) +                m2m_data[field.name] = attrs.pop(field.name) +        # Update an existing instance...          if instance is not None:              for key, val in attrs.items():                  setattr(instance, key, val) +        # ...or create a new instance          else:              instance = self.opts.model(**attrs) +        # Any relations that cannot be set until we've +        # saved the model get hidden away on these +        # private attributes, so we can deal with them +        # at the point of save. +        instance._related_data = related_data +        instance._m2m_data = m2m_data +          return instance      def from_native(self, data, files): @@ -627,15 +750,15 @@ class ModelSerializer(Serializer):          """          obj.save(**kwargs) -        if getattr(self, 'm2m_data', None): -            for accessor_name, object_list in self.m2m_data.items(): -                setattr(self.object, accessor_name, object_list) -            self.m2m_data = {} +        if getattr(obj, '_m2m_data', None): +            for accessor_name, object_list in obj._m2m_data.items(): +                setattr(obj, accessor_name, object_list) +            del(obj._m2m_data) -        if getattr(self, 'related_data', None): -            for accessor_name, object_list in self.related_data.items(): -                setattr(self.object, accessor_name, object_list) -            self.related_data = {} +        if getattr(obj, '_related_data', None): +            for accessor_name, related in obj._related_data.items(): +                setattr(obj, accessor_name, related) +            del(obj._related_data)  class HyperlinkedModelSerializerOptions(ModelSerializerOptions): @@ -690,3 +813,13 @@ class HyperlinkedModelSerializer(ModelSerializer):              'many': to_many          }          return HyperlinkedRelatedField(**kwargs) + +    def get_identity(self, data): +        """ +        This hook is required for bulk update. +        We need to override the default, to use the url as the identity. +        """ +        try: +            return data.get('url', None) +        except AttributeError: +            return None diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index c21ddcd7..b6ab2de3 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch  from django.http import QueryDict  from django.utils.html import escape  from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import urlparse -from rest_framework.compat import force_text -from rest_framework.compat import six -import re -import string +from rest_framework.compat import urlparse, force_text, six, smart_urlquote +import re, string  register = template.Library() @@ -112,22 +109,6 @@ def replace_query_param(url, key, val):  class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') -# Bunch of stuff cloned from urlize -LEADING_PUNCTUATION = ['(', '<', '<', '"', "'"] -TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '>', '"', "'"] -DOTS = ['·', '*', '\xe2\x80\xa2', '•', '•', '•'] -unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') -word_split_re = re.compile(r'(\s+)') -punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \ -    ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), -    '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) -simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') -link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+') -html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) -hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) -trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') - -  # And the template tags themselves...  @register.simple_tag @@ -195,15 +176,25 @@ def add_class(value, css_class):      return value +# Bunch of stuff cloned from urlize +TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] +WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), +                        ('"', '"'), ("'", "'")] +word_split_re = re.compile(r'(\s+)') +simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE) +simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) +simple_email_re = re.compile(r'^\S+@\S+\.\S+$') + +  @register.filter  def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):      """      Converts any URLs in text into clickable links. -    Works on http://, https://, www. links and links ending in .org, .net or -    .com. Links can have trailing punctuation (periods, commas, close-parens) -    and leading punctuation (opening parens) and it'll still do the right -    thing. +    Works on http://, https://, www. links, and also on links ending in one of +    the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org). +    Links can have trailing punctuation (periods, commas, close-parens) and +    leading punctuation (opening parens) and it'll still do the right thing.      If trim_url_limit is not None, the URLs in link text longer than this limit      will truncated to trim_url_limit-3 characters and appended with an elipsis. @@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru      trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x      safe_input = isinstance(text, SafeData)      words = word_split_re.split(force_text(text)) -    nofollow_attr = nofollow and ' rel="nofollow"' or ''      for i, word in enumerate(words):          match = None          if '.' in word or '@' in word or ':' in word: -            match = punctuation_re.match(word) -        if match: -            lead, middle, trail = match.groups() +            # Deal with punctuation. +            lead, middle, trail = '', word, '' +            for punctuation in TRAILING_PUNCTUATION: +                if middle.endswith(punctuation): +                    middle = middle[:-len(punctuation)] +                    trail = punctuation + trail +            for opening, closing in WRAPPING_PUNCTUATION: +                if middle.startswith(opening): +                    middle = middle[len(opening):] +                    lead = lead + opening +                # Keep parentheses at the end only if they're balanced. +                if (middle.endswith(closing) +                    and middle.count(closing) == middle.count(opening) + 1): +                    middle = middle[:-len(closing)] +                    trail = closing + trail +              # Make URL we want to point to.              url = None -            if middle.startswith('http://') or middle.startswith('https://'): -                url = middle -            elif middle.startswith('www.') or ('@' not in middle and \ -                    middle and middle[0] in string.ascii_letters + string.digits and \ -                    (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): -                url = 'http://%s' % middle -            elif '@' in middle and not ':' in middle and simple_email_re.match(middle): -                url = 'mailto:%s' % middle +            nofollow_attr = ' rel="nofollow"' if nofollow else '' +            if simple_url_re.match(middle): +                url = smart_urlquote(middle) +            elif simple_url_2_re.match(middle): +                url = smart_urlquote('http://%s' % middle) +            elif not ':' in middle and simple_email_re.match(middle): +                local, domain = middle.rsplit('@', 1) +                try: +                    domain = domain.encode('idna').decode('ascii') +                except UnicodeError: +                    continue +                url = 'mailto:%s@%s' % (local, domain)                  nofollow_attr = '' +              # Make link.              if url:                  trimmed = trim_url(middle) @@ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru              words[i] = mark_safe(word)          elif autoescape:              words[i] = escape(word) -    return mark_safe(''.join(words)) +    return ''.join(words) diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index b663ca48..8e6d3e51 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -466,17 +466,13 @@ class OAuth2Tests(TestCase):      def _create_authorization_header(self, token=None):          return "Bearer {0}".format(token or self.access_token.token) -    def _client_credentials_params(self): -        return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} -      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')      def test_get_form_with_wrong_authorization_header_token_type_failing(self):          """Ensure that a wrong token type lead to the correct HTTP error status code"""          auth = "Wrong token-type-obsviously"          response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 401) -        params = self._client_credentials_params() -        response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 401)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -485,8 +481,7 @@ class OAuth2Tests(TestCase):          auth = "Bearer wrong token format"          response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 401) -        params = self._client_credentials_params() -        response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 401)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -495,33 +490,21 @@ class OAuth2Tests(TestCase):          auth = "Bearer wrong-token"          response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 401) -        params = self._client_credentials_params() -        response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) -        self.assertEqual(response.status_code, 401) - -    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') -    def test_get_form_with_wrong_client_data_failing_auth(self): -        """Ensure GETing form over OAuth with incorrect client credentials fails""" -        auth = self._create_authorization_header() -        params = self._client_credentials_params() -        params['client_id'] += 'a' -        response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 401)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')      def test_get_form_passing_auth(self):          """Ensure GETing form over OAuth with correct client credentials succeed"""          auth = self._create_authorization_header() -        params = self._client_credentials_params() -        response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')      def test_post_form_passing_auth(self):          """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""          auth = self._create_authorization_header() -        params = self._client_credentials_params() -        response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -529,16 +512,14 @@ class OAuth2Tests(TestCase):          """Ensure POSTing when there is no OAuth access token in db fails"""          self.access_token.delete()          auth = self._create_authorization_header() -        params = self._client_credentials_params() -        response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)          self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')      def test_post_form_with_refresh_token_failing_auth(self):          """Ensure POSTing with refresh token instead of access token fails"""          auth = self._create_authorization_header(token=self.refresh_token.token) -        params = self._client_credentials_params() -        response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)          self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -547,8 +528,7 @@ class OAuth2Tests(TestCase):          self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10)  # 10 seconds late          self.access_token.save()          auth = self._create_authorization_header() -        params = self._client_credentials_params() -        response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)          self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))          self.assertIn('Invalid token', response.content) @@ -559,10 +539,9 @@ class OAuth2Tests(TestCase):          read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']          read_only_access_token.save()          auth = self._create_authorization_header(token=read_only_access_token.token) -        params = self._client_credentials_params() -        response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200) -        response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @@ -572,6 +551,5 @@ class OAuth2Tests(TestCase):          read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']          read_write_access_token.save()          auth = self._create_authorization_header(token=read_write_access_token.token) -        params = self._client_credentials_params() -        response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200) diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py index fd6de779..19c663d8 100644 --- a/rest_framework/tests/fields.py +++ b/rest_framework/tests/fields.py @@ -153,12 +153,22 @@ class DateFieldTest(TestCase):      def test_to_native(self):          """ -        Make sure to_native() returns isoformat as default. +        Make sure to_native() returns datetime as default.          """          f = serializers.DateField()          result_1 = f.to_native(datetime.date(1984, 7, 31)) +        self.assertEqual(datetime.date(1984, 7, 31), result_1) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with 'iso-8601' returns iso formated date. +        """ +        f = serializers.DateField(format='iso-8601') + +        result_1 = f.to_native(datetime.date(1984, 7, 31)) +          self.assertEqual('1984-07-31', result_1)      def test_to_native_custom_format(self): @@ -289,6 +299,22 @@ class DateTimeFieldTest(TestCase):          result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59))          result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) +        self.assertEqual(datetime.datetime(1984, 7, 31), result_1) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31), result_2) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59), result_3) +        self.assertEqual(datetime.datetime(1984, 7, 31, 4, 31, 59, 200), result_4) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with format=iso-8601 returns iso formatted datetime. +        """ +        f = serializers.DateTimeField(format='iso-8601') + +        result_1 = f.to_native(datetime.datetime(1984, 7, 31)) +        result_2 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31)) +        result_3 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59)) +        result_4 = f.to_native(datetime.datetime(1984, 7, 31, 4, 31, 59, 200)) +          self.assertEqual('1984-07-31T00:00:00', result_1)          self.assertEqual('1984-07-31T04:31:00', result_2)          self.assertEqual('1984-07-31T04:31:59', result_3) @@ -419,13 +445,26 @@ class TimeFieldTest(TestCase):      def test_to_native(self):          """ -        Make sure to_native() returns isoformat as default. +        Make sure to_native() returns time object as default.          """          f = serializers.TimeField()          result_1 = f.to_native(datetime.time(4, 31))          result_2 = f.to_native(datetime.time(4, 31, 59))          result_3 = f.to_native(datetime.time(4, 31, 59, 200)) +        self.assertEqual(datetime.time(4, 31), result_1) +        self.assertEqual(datetime.time(4, 31, 59), result_2) +        self.assertEqual(datetime.time(4, 31, 59, 200), result_3) + +    def test_to_native_iso(self): +        """ +        Make sure to_native() with format='iso-8601' returns iso formatted time. +        """ +        f = serializers.TimeField(format='iso-8601') +        result_1 = f.to_native(datetime.time(4, 31)) +        result_2 = f.to_native(datetime.time(4, 31, 59)) +        result_3 = f.to_native(datetime.time(4, 31, 59, 200)) +          self.assertEqual('04:31:00', result_1)          self.assertEqual('04:31:59', result_2)          self.assertEqual('04:31:59.000200', result_3) diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index fe92e0bc..1a71558c 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -1,11 +1,12 @@  from __future__ import unicode_literals  import datetime  from decimal import Decimal +from django.core.urlresolvers import reverse  from django.test import TestCase  from django.test.client import RequestFactory  from django.utils import unittest  from rest_framework import generics, status, filters -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, patterns, url  from rest_framework.tests.models import FilterableItem, BasicModel  factory = RequestFactory() @@ -46,12 +47,21 @@ if django_filters:          filter_class = MisconfiguredFilter          filter_backend = filters.DjangoFilterBackend +    class FilterClassDetailView(generics.RetrieveAPIView): +        model = FilterableItem +        filter_class = SeveralFieldsFilter +        filter_backend = filters.DjangoFilterBackend + +    urlpatterns = patterns('', +        url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), +        url(r'^$', FilterClassRootView.as_view(), name='root-view'), +    ) -class IntegrationTestFiltering(TestCase): -    """ -    Integration tests for filtered list views. -    """ +class CommonFilteringTestCase(TestCase): +    def _serialize_object(self, obj): +        return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} +          def setUp(self):          """          Create 10 FilterableItem instances. @@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):          self.objects = FilterableItem.objects          self.data = [ -            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} +            self._serialize_object(obj)              for obj in self.objects.all()          ] + +class IntegrationTestFiltering(CommonFilteringTestCase): +    """ +    Integration tests for filtered list views. +    """ +      @unittest.skipUnless(django_filters, 'django-filters not installed')      def test_get_filtered_fields_root_view(self):          """ @@ -95,7 +111,7 @@ class IntegrationTestFiltering(TestCase):          request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-09-22'          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() == search_date] +        expected_data = [f for f in self.data if f['date'] == search_date]          self.assertEqual(response.data, expected_data)      @unittest.skipUnless(django_filters, 'django-filters not installed') @@ -125,7 +141,7 @@ class IntegrationTestFiltering(TestCase):          request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-10-02'          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date] +        expected_data = [f for f in self.data if f['date'] > search_date]          self.assertEqual(response.data, expected_data)          # Tests that the text filter set with 'icontains' in the filter class works. @@ -142,8 +158,7 @@ class IntegrationTestFiltering(TestCase):          request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) -        expected_data = [f for f in self.data if -                         datetime.datetime.strptime(f['date'], '%Y-%m-%d').date() > search_date and +        expected_data = [f for f in self.data if f['date'] > search_date and                           f['decimal'] < search_decimal]          self.assertEqual(response.data, expected_data) @@ -168,3 +183,50 @@ class IntegrationTestFiltering(TestCase):          request = factory.get('/?integer=%s' % search_integer)          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class IntegrationTestDetailFiltering(CommonFilteringTestCase): +    """ +    Integration tests for filtered detail views. +    """ +    urls = 'rest_framework.tests.filterset' +     +    def _get_url(self, item): +        return reverse('detail-view', kwargs=dict(pk=item.pk)) + +    @unittest.skipUnless(django_filters, 'django-filters not installed') +    def test_get_filtered_detail_view(self): +        """ +        GET requests to filtered RetrieveAPIView that have a filter_class set +        should return filtered results. +        """ +        item = self.objects.all()[0] +        data = self._serialize_object(item) + +        # Basic test with no filter. +        response = self.client.get(self._get_url(item)) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, data) + +        # Tests that the decimal filter set that should fail. +        search_decimal = Decimal('4.25') +        high_item = self.objects.filter(decimal__gt=search_decimal)[0] +        response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) +        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +        # Tests that the decimal filter set that should succeed. +        search_decimal = Decimal('4.25') +        low_item = self.objects.filter(decimal__lt=search_decimal)[0] +        low_item_data = self._serialize_object(low_item) +        response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, low_item_data) +         +        # Tests that multiple filters works. +        search_decimal = Decimal('5.25') +        search_date = datetime.date(2012, 10, 2) +        valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] +        valid_item_data = self._serialize_object(valid_item) +        response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data, valid_item_data) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 1a2d68a6..6b8ef02f 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -102,7 +102,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.objects = FilterableItem.objects          self.data = [ -            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} +            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}              for obj in self.objects.all()          ] @@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):          view = FilterFieldsRootView.as_view()          EXPECTED_NUM_QUERIES = 2 -        if django.VERSION < (1, 4): -            # On Django 1.3 we need to use django-filter 0.5.4 -            # -            # The filter objects there don't expose a `.count()` method, -            # which means we only make a single query *but* it's a single -            # query across *all* of the queryset, instead of a COUNT and then -            # a SELECT with a LIMIT. -            # -            # Although this is fewer queries, it's actually a regression. -            EXPECTED_NUM_QUERIES = 1          request = factory.get('/?decimal=15.20')          with self.assertNumQueries(EXPECTED_NUM_QUERIES): diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index beb372c2..05217f35 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -112,7 +112,7 @@ class BasicTests(TestCase):          self.expected = {              'email': 'tom@example.com',              'content': 'Happy new year!', -            'created': '2012-01-01T00:00:00', +            'created': datetime.datetime(2012, 1, 1),              'sub_comment': 'And Merry Christmas!'          }          self.person_data = {'name': 'dwight', 'age': 35} @@ -261,34 +261,6 @@ class ValidationTests(TestCase):          self.assertEqual(serializer.is_valid(), True)          self.assertEqual(serializer.errors, {}) -    def test_bad_type_data_is_false(self): -        """ -        Data of the wrong type is not valid. -        """ -        data = ['i am', 'a', 'list'] -        serializer = CommentSerializer(self.comment, data=data, many=True) -        self.assertEqual(serializer.is_valid(), False) -        self.assertTrue(isinstance(serializer.errors, list)) - -        self.assertEqual( -            serializer.errors, -            [ -                {'non_field_errors': ['Invalid data']}, -                {'non_field_errors': ['Invalid data']}, -                {'non_field_errors': ['Invalid data']} -            ] -        ) - -        data = 'and i am a string' -        serializer = CommentSerializer(self.comment, data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) - -        data = 42 -        serializer = CommentSerializer(self.comment, data=data) -        self.assertEqual(serializer.is_valid(), False) -        self.assertEqual(serializer.errors, {'non_field_errors': ['Invalid data']}) -      def test_cross_field_validation(self):          class CommentSerializerWithCrossFieldValidator(CommentSerializer): diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py new file mode 100644 index 00000000..8b0ded1a --- /dev/null +++ b/rest_framework/tests/serializer_bulk_update.py @@ -0,0 +1,278 @@ +""" +Tests to cover bulk create and update using serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class BulkCreateSerializerTests(TestCase): +    """ +    Creating multiple instances using serializers. +    """ + +    def setUp(self): +        class BookSerializer(serializers.Serializer): +            id = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            author = serializers.CharField(max_length=100) + +        self.BookSerializer = BookSerializer + +    def test_bulk_create_success(self): +        """ +        Correct bulk update serialization should return the input data. +        """ + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 2, +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] + +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, data) + +    def test_bulk_create_errors(self): +        """ +        Correct bulk update serialization should return the input data. +        """ + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 'foo', +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] +        expected_errors = [ +            {}, +            {}, +            {'id': ['Enter a whole number.']} +        ] + +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_list_datatype(self): +        """ +        Data containing list of incorrect data type should return errors. +        """ +        data = ['foo', 'bar', 'baz'] +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = [ +                {'non_field_errors': ['Invalid data']}, +                {'non_field_errors': ['Invalid data']}, +                {'non_field_errors': ['Invalid data']} +        ] + +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_single_datatype(self): +        """ +        Data containing a single incorrect data type should return errors. +        """ +        data = 123 +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = {'non_field_errors': ['Expected a list of items.']} + +        self.assertEqual(serializer.errors, expected_errors) + +    def test_invalid_single_object(self): +        """ +        Data containing only a single object, instead of a list of objects +        should return errors. +        """ +        data = { +            'id': 0, +            'title': 'The electric kool-aid acid test', +            'author': 'Tom Wolfe' +        } +        serializer = self.BookSerializer(data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) + +        expected_errors = {'non_field_errors': ['Expected a list of items.']} + +        self.assertEqual(serializer.errors, expected_errors) + + +class BulkUpdateSerializerTests(TestCase): +    """ +    Updating multiple instances using serializers. +    """ + +    def setUp(self): +        class Book(object): +            """ +            A data type that can be persisted to a mock storage backend +            with `.save()` and `.delete()`. +            """ +            object_map = {} + +            def __init__(self, id, title, author): +                self.id = id +                self.title = title +                self.author = author + +            def save(self): +                Book.object_map[self.id] = self + +            def delete(self): +                del Book.object_map[self.id] + +        class BookSerializer(serializers.Serializer): +            id = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            author = serializers.CharField(max_length=100) + +            def restore_object(self, attrs, instance=None): +                if instance: +                    instance.id = attrs['id'] +                    instance.title = attrs['title'] +                    instance.author = attrs['author'] +                    return instance +                return Book(**attrs) + +        self.Book = Book +        self.BookSerializer = BookSerializer + +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 1, +                'title': 'If this is a man', +                'author': 'Primo Levi' +            }, { +                'id': 2, +                'title': 'The wind-up bird chronicle', +                'author': 'Haruki Murakami' +            } +        ] + +        for item in data: +            book = Book(item['id'], item['title'], item['author']) +            book.save() + +    def books(self): +        """ +        Return all the objects in the mock storage backend. +        """ +        return self.Book.object_map.values() + +    def test_bulk_update_success(self): +        """ +        Correct bulk update serialization should return the input data. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 2, +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) +        serializer.save() +        new_data = self.BookSerializer(self.books(), many=True).data + +        self.assertEqual(data, new_data) + +    def test_bulk_update_and_create(self): +        """ +        Bulk update serialization may also include created items. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 3, +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) +        serializer.save() +        new_data = self.BookSerializer(self.books(), many=True).data +        self.assertEqual(data, new_data) + +    def test_bulk_update_invalid_create(self): +        """ +        Bulk update serialization without allow_add_remove may not create items. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 3, +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        expected_errors = [ +            {}, +            {'non_field_errors': ['Cannot create a new item, only existing items may be updated.']} +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + +    def test_bulk_update_error(self): +        """ +        Incorrect bulk update serialization should return error data. +        """ +        data = [ +            { +                'id': 0, +                'title': 'The electric kool-aid acid test', +                'author': 'Tom Wolfe' +            }, { +                'id': 'foo', +                'title': 'Kafka on the shore', +                'author': 'Haruki Murakami' +            } +        ] +        expected_errors = [ +            {}, +            {'id': ['Enter a whole number.']} +        ] +        serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py new file mode 100644 index 00000000..6a29c652 --- /dev/null +++ b/rest_framework/tests/serializer_nested.py @@ -0,0 +1,246 @@ +""" +Tests to cover nested serializers. + +Doesn't cover model serializers. +""" +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework import serializers + + +class WritableNestedSerializerBasicTests(TestCase): +    """ +    Tests for deserializing nested entities. +    Basic tests that use serializers that simply restore to dicts. +    """ + +    def setUp(self): +        class TrackSerializer(serializers.Serializer): +            order = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            duration = serializers.IntegerField() + +        class AlbumSerializer(serializers.Serializer): +            album_name = serializers.CharField(max_length=100) +            artist = serializers.CharField(max_length=100) +            tracks = TrackSerializer(many=True) + +        self.AlbumSerializer = AlbumSerializer + +    def test_nested_validation_success(self): +        """ +        Correct nested serialization should return the input data. +        """ + +        data = { +            'album_name': 'Discovery', +            'artist': 'Daft Punk', +            'tracks': [ +                {'order': 1, 'title': 'One More Time', 'duration': 235}, +                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                {'order': 3, 'title': 'Digital Love', 'duration': 239} +            ] +        } + +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, data) + +    def test_nested_validation_error(self): +        """ +        Incorrect nested serialization should return appropriate error data. +        """ + +        data = { +            'album_name': 'Discovery', +            'artist': 'Daft Punk', +            'tracks': [ +                {'order': 1, 'title': 'One More Time', 'duration': 235}, +                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} +            ] +        } +        expected_errors = { +            'tracks': [ +                {}, +                {}, +                {'duration': ['Enter a whole number.']} +            ] +        } + +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + +    def test_many_nested_validation_error(self): +        """ +        Incorrect nested serialization should return appropriate error data +        when multiple entities are being deserialized. +        """ + +        data = [ +            { +                'album_name': 'Russian Red', +                'artist': 'I Love Your Glasses', +                'tracks': [ +                    {'order': 1, 'title': 'Cigarettes', 'duration': 121}, +                    {'order': 2, 'title': 'No Past Land', 'duration': 198}, +                    {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} +                ] +            }, +            { +                'album_name': 'Discovery', +                'artist': 'Daft Punk', +                'tracks': [ +                    {'order': 1, 'title': 'One More Time', 'duration': 235}, +                    {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                    {'order': 3, 'title': 'Digital Love', 'duration': 'foobar'} +                ] +            } +        ] +        expected_errors = [ +            {}, +            { +                'tracks': [ +                    {}, +                    {}, +                    {'duration': ['Enter a whole number.']} +                ] +            } +        ] + +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), False) +        self.assertEqual(serializer.errors, expected_errors) + + +class WritableNestedSerializerObjectTests(TestCase): +    """ +    Tests for deserializing nested entities. +    These tests use serializers that restore to concrete objects. +    """ + +    def setUp(self): +        # Couple of concrete objects that we're going to deserialize into +        class Track(object): +            def __init__(self, order, title, duration): +                self.order, self.title, self.duration = order, title, duration + +            def __eq__(self, other): +                return ( +                    self.order == other.order and +                    self.title == other.title and +                    self.duration == other.duration +                ) + +        class Album(object): +            def __init__(self, album_name, artist, tracks): +                self.album_name, self.artist, self.tracks = album_name, artist, tracks + +            def __eq__(self, other): +                return ( +                    self.album_name == other.album_name and +                    self.artist == other.artist and +                    self.tracks == other.tracks +                ) + +        # And their corresponding serializers +        class TrackSerializer(serializers.Serializer): +            order = serializers.IntegerField() +            title = serializers.CharField(max_length=100) +            duration = serializers.IntegerField() + +            def restore_object(self, attrs, instance=None): +                return Track(attrs['order'], attrs['title'], attrs['duration']) + +        class AlbumSerializer(serializers.Serializer): +            album_name = serializers.CharField(max_length=100) +            artist = serializers.CharField(max_length=100) +            tracks = TrackSerializer(many=True) + +            def restore_object(self, attrs, instance=None): +                return Album(attrs['album_name'], attrs['artist'], attrs['tracks']) + +        self.Album, self.Track = Album, Track +        self.AlbumSerializer = AlbumSerializer + +    def test_nested_validation_success(self): +        """ +        Correct nested serialization should return a restored object +        that corresponds to the input data. +        """ + +        data = { +            'album_name': 'Discovery', +            'artist': 'Daft Punk', +            'tracks': [ +                {'order': 1, 'title': 'One More Time', 'duration': 235}, +                {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                {'order': 3, 'title': 'Digital Love', 'duration': 239} +            ] +        } +        expected_object = self.Album( +            album_name='Discovery', +            artist='Daft Punk', +            tracks=[ +                self.Track(order=1, title='One More Time', duration=235), +                self.Track(order=2, title='Aerodynamic', duration=184), +                self.Track(order=3, title='Digital Love', duration=239), +            ] +        ) + +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected_object) + +    def test_many_nested_validation_success(self): +        """ +        Correct nested serialization should return multiple restored objects +        that corresponds to the input data when multiple objects are +        being deserialized. +        """ + +        data = [ +            { +                'album_name': 'Russian Red', +                'artist': 'I Love Your Glasses', +                'tracks': [ +                    {'order': 1, 'title': 'Cigarettes', 'duration': 121}, +                    {'order': 2, 'title': 'No Past Land', 'duration': 198}, +                    {'order': 3, 'title': 'They Don\'t Believe', 'duration': 191} +                ] +            }, +            { +                'album_name': 'Discovery', +                'artist': 'Daft Punk', +                'tracks': [ +                    {'order': 1, 'title': 'One More Time', 'duration': 235}, +                    {'order': 2, 'title': 'Aerodynamic', 'duration': 184}, +                    {'order': 3, 'title': 'Digital Love', 'duration': 239} +                ] +            } +        ] +        expected_object = [ +            self.Album( +                album_name='Russian Red', +                artist='I Love Your Glasses', +                tracks=[ +                    self.Track(order=1, title='Cigarettes', duration=121), +                    self.Track(order=2, title='No Past Land', duration=198), +                    self.Track(order=3, title='They Don\'t Believe', duration=191), +                ] +            ), +            self.Album( +                album_name='Discovery', +                artist='Daft Punk', +                tracks=[ +                    self.Track(order=1, title='One More Time', duration=235), +                    self.Track(order=2, title='Aerodynamic', duration=184), +                    self.Track(order=3, title='Digital Love', duration=239), +                ] +            ) +        ] + +        serializer = self.AlbumSerializer(data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.object, expected_object) diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py deleted file mode 100644 index e1644a6b..00000000 --- a/rest_framework/tests/status.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Tests for the status module""" -from __future__ import unicode_literals -from django.test import TestCase -from rest_framework import status - - -class TestStatus(TestCase): -    """Simple sanity test to check the status module""" - -    def test_status(self): -        """Ensure the status module is present and correct.""" -        self.assertEqual(200, status.HTTP_200_OK) -        self.assertEqual(404, status.HTTP_404_NOT_FOUND)  | 
