diff options
Diffstat (limited to 'rest_framework')
55 files changed, 1169 insertions, 232 deletions
| diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index f5483b9d..2d76b55d 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -8,10 +8,10 @@ ______ _____ _____ _____    __                                             _  """  __title__ = 'Django REST framework' -__version__ = '2.3.10' +__version__ = '2.3.13'  __author__ = 'Tom Christie'  __license__ = 'BSD 2-Clause' -__copyright__ = 'Copyright 2011-2013 Tom Christie' +__copyright__ = 'Copyright 2011-2014 Tom Christie'  # Version synonym  VERSION = __version__ diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 1f8d37fa..cbc83574 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -7,6 +7,7 @@ import base64  from django.contrib.auth import authenticate  from django.core.exceptions import ImproperlyConfigured  from django.middleware.csrf import CsrfViewMiddleware +from django.conf import settings  from rest_framework import exceptions, HTTP_HEADER_ENCODING  from rest_framework.compat import oauth, oauth_provider, oauth_provider_store  from rest_framework.compat import oauth2_provider, provider_now, check_nonce @@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication):      OAuth 2 authentication backend using `django-oauth2-provider`      """      www_authenticate_realm = 'api' +    allow_query_params_token = settings.DEBUG      def __init__(self, *args, **kwargs):          super(OAuth2Authentication, self).__init__(*args, **kwargs) @@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication):          auth = get_authorization_header(request).split() -        if not auth or auth[0].lower() != b'bearer': +        if auth and auth[0].lower() == b'bearer': +            access_token = auth[1] +        elif 'access_token' in request.POST: +            access_token = request.POST['access_token'] +        elif 'access_token' in request.GET and self.allow_query_params_token: +            access_token = request.GET['access_token'] +        else:              return None          if len(auth) == 1: @@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication):              msg = 'Invalid bearer header. Token string should not contain spaces.'              raise exceptions.AuthenticationFailed(msg) -        return self.authenticate_credentials(request, auth[1]) +        return self.authenticate_credentials(request, access_token)      def authenticate_credentials(self, request, access_token):          """ @@ -326,11 +334,11 @@ class OAuth2Authentication(BaseAuthentication):          """          try: -            token = oauth2_provider.models.AccessToken.objects.select_related('user') +            token = oauth2_provider.oauth2.models.AccessToken.objects.select_related('user')              # provider_now switches to timezone aware datetime when              # the oauth2_provider version supports to it.              token = token.get(token=access_token, expires__gt=provider_now()) -        except oauth2_provider.models.AccessToken.DoesNotExist: +        except oauth2_provider.oauth2.models.AccessToken.DoesNotExist:              raise exceptions.AuthenticationFailed('Invalid token')          user = token.user diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 024f62bf..8eac2cc4 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -1,5 +1,5 @@ -import uuid -import hmac +import binascii +import os  from hashlib import sha1  from django.conf import settings  from django.db import models @@ -34,8 +34,7 @@ class Token(models.Model):          return super(Token, self).save(*args, **kwargs)      def generate_key(self): -        unique = uuid.uuid4() -        return hmac.new(unique.bytes, digestmod=sha1).hexdigest() +        return binascii.hexlify(os.urandom(20))      def __unicode__(self):          return self.key diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 45045c0f..a013a155 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -121,7 +121,7 @@ from django.test.client import RequestFactory as DjangoRequestFactory  from django.test.client import FakePayload  try:      # In 1.5 the test client uses force_bytes -    from django.utils.encoding import force_bytes_or_smart_bytes +    from django.utils.encoding import force_bytes as force_bytes_or_smart_bytes  except ImportError:      # In 1.4 the test client just uses smart_str      from django.utils.encoding import smart_str as force_bytes_or_smart_bytes @@ -216,13 +216,10 @@ except (ImportError, ImproperlyConfigured):  # OAuth 2 support is optional  try: -    import provider.oauth2 as oauth2_provider -    from provider.oauth2 import models as oauth2_provider_models -    from provider.oauth2 import forms as oauth2_provider_forms +    import provider as oauth2_provider      from provider import scope as oauth2_provider_scope      from provider import constants as oauth2_constants -    from provider import __version__ as provider_version -    if provider_version in ('0.2.3', '0.2.4'): +    if oauth2_provider.__version__ in ('0.2.3', '0.2.4'):          # 0.2.3 and 0.2.4 are supported version that do not support          # timezone aware datetimes          import datetime @@ -232,8 +229,6 @@ try:          from django.utils.timezone import now as provider_now  except ImportError:      oauth2_provider = None -    oauth2_provider_models = None -    oauth2_provider_forms = None      oauth2_provider_scope = None      oauth2_constants = None      provider_now = None @@ -251,3 +246,23 @@ if six.PY3:  else:      def is_non_str_iterable(obj):          return hasattr(obj, '__iter__') + + +try: +    from django.utils.encoding import python_2_unicode_compatible +except ImportError: +    def python_2_unicode_compatible(klass): +        """ +        A decorator that defines __unicode__ and __str__ methods under Python 2. +        Under Python 3 it does nothing. + +        To support Python 2 and 3 with a single code base, define a __str__ method +        returning text and apply this decorator to the class. +        """ +        if '__str__' not in klass.__dict__: +            raise ValueError("@python_2_unicode_compatible cannot be applied " +                             "to %s because it doesn't define __str__()." % +                             klass.__name__) +        klass.__unicode__ = klass.__str__ +        klass.__str__ = lambda self: self.__unicode__().encode('utf-8') +        return klass diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 425a7214..5f774a9f 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -6,47 +6,42 @@ In addition Django's built in 403 and 404 exceptions are handled.  """  from __future__ import unicode_literals  from rest_framework import status +import math  class APIException(Exception):      """      Base class for REST framework exceptions. -    Subclasses should provide `.status_code` and `.detail` properties. +    Subclasses should provide `.status_code` and `.default_detail` properties.      """ -    pass +    status_code = status.HTTP_500_INTERNAL_SERVER_ERROR +    default_detail = '' +    def __init__(self, detail=None): +        self.detail = detail or self.default_detail + +    def __str__(self): +        return self.detail  class ParseError(APIException):      status_code = status.HTTP_400_BAD_REQUEST      default_detail = 'Malformed request.' -    def __init__(self, detail=None): -        self.detail = detail or self.default_detail -  class AuthenticationFailed(APIException):      status_code = status.HTTP_401_UNAUTHORIZED      default_detail = 'Incorrect authentication credentials.' -    def __init__(self, detail=None): -        self.detail = detail or self.default_detail -  class NotAuthenticated(APIException):      status_code = status.HTTP_401_UNAUTHORIZED      default_detail = 'Authentication credentials were not provided.' -    def __init__(self, detail=None): -        self.detail = detail or self.default_detail -  class PermissionDenied(APIException):      status_code = status.HTTP_403_FORBIDDEN      default_detail = 'You do not have permission to perform this action.' -    def __init__(self, detail=None): -        self.detail = detail or self.default_detail -  class MethodNotAllowed(APIException):      status_code = status.HTTP_405_METHOD_NOT_ALLOWED @@ -75,14 +70,14 @@ class UnsupportedMediaType(APIException):  class Throttled(APIException):      status_code = status.HTTP_429_TOO_MANY_REQUESTS -    default_detail = "Request was throttled." +    default_detail = 'Request was throttled.'      extra_detail = "Expected available in %d second%s."      def __init__(self, wait=None, detail=None): -        import math -        self.wait = wait and math.ceil(wait) or None -        if wait is not None: -            format = detail or self.default_detail + self.extra_detail -            self.detail = format % (self.wait, self.wait != 1 and 's' or '') -        else: +        if wait is None:              self.detail = detail or self.default_detail +            self.wait = None +        else: +            format = (detail or self.default_detail) + self.extra_detail +            self.detail = format % (wait, wait != 1 and 's' or '') +            self.wait = math.ceil(wait) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 16485b41..533de28c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -166,7 +166,7 @@ class Field(object):          Called to set up a field prior to field_to_native or field_from_native.          parent - The parent serializer. -        model_field - The model field this field corresponds to, if one exists. +        field_name - The name of the field being initialized.          """          self.parent = parent          self.root = parent.root or parent @@ -248,6 +248,7 @@ class WritableField(Field):      """      Base for read/write fields.      """ +    write_only = False      default_validators = []      default_error_messages = {          'required': _('This field is required.'), @@ -257,13 +258,17 @@ class WritableField(Field):      default = None      def __init__(self, source=None, label=None, help_text=None, -                 read_only=False, required=None, +                 read_only=False, write_only=False, required=None,                   validators=[], error_messages=None, widget=None,                   default=None, blank=None):          super(WritableField, self).__init__(source=source, label=label, help_text=help_text)          self.read_only = read_only +        self.write_only = write_only + +        assert not (read_only and write_only), "Cannot set read_only=True and write_only=True" +          if required is None:              self.required = not(read_only)          else: @@ -291,6 +296,11 @@ class WritableField(Field):          result.validators = self.validators[:]          return result +    def get_default_value(self): +        if is_simple_callable(self.default): +            return self.default() +        return self.default +      def validate(self, value):          if value in validators.EMPTY_VALUES and self.required:              raise ValidationError(self.error_messages['required']) @@ -313,6 +323,11 @@ class WritableField(Field):          if errors:              raise ValidationError(errors) +    def field_to_native(self, obj, field_name): +        if self.write_only: +            return None +        return super(WritableField, self).field_to_native(obj, field_name) +      def field_from_native(self, data, files, field_name, into):          """          Given a dictionary and a field name, updates the dictionary `into`, @@ -334,10 +349,7 @@ class WritableField(Field):          except KeyError:              if self.default is not None and not self.partial:                  # Note: partial updates shouldn't set defaults -                if is_simple_callable(self.default): -                    native = self.default() -                else: -                    native = self.default +                native = self.get_default_value()              else:                  if self.required:                      raise ValidationError(self.error_messages['required']) @@ -465,7 +477,8 @@ class URLField(CharField):      type_label = 'url'      def __init__(self, **kwargs): -        kwargs['validators'] = [validators.URLValidator()] +        if not 'validators' in kwargs: +            kwargs['validators'] = [validators.URLValidator()]          super(URLField, self).__init__(**kwargs) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 5c6a187c..96d15eb9 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -3,8 +3,10 @@ Provides generic filtering backends that can be used to filter the results  returned by list views.  """  from __future__ import unicode_literals +from django.core.exceptions import ImproperlyConfigured  from django.db import models  from rest_framework.compat import django_filters, six, guardian, get_model_name +from rest_framework.settings import api_settings  from functools import reduce  import operator @@ -68,7 +70,8 @@ class DjangoFilterBackend(BaseFilterBackend):  class SearchFilter(BaseFilterBackend): -    search_param = 'search'  # The URL query parameter used for the search. +    # The URL query parameter used for the search. +    search_param = api_settings.SEARCH_PARAM      def get_search_terms(self, request):          """ @@ -106,7 +109,9 @@ class SearchFilter(BaseFilterBackend):  class OrderingFilter(BaseFilterBackend): -    ordering_param = 'ordering'  # The URL query parameter used for the ordering. +    # The URL query parameter used for the ordering. +    ordering_param = api_settings.ORDERING_PARAM +    ordering_fields = None      def get_ordering(self, request):          """ @@ -122,17 +127,34 @@ class OrderingFilter(BaseFilterBackend):              return (ordering,)          return ordering -    def remove_invalid_fields(self, queryset, ordering): -        field_names = [field.name for field in queryset.model._meta.fields] -        field_names += queryset.query.aggregates.keys() -        return [term for term in ordering if term.lstrip('-') in field_names] +    def remove_invalid_fields(self, queryset, ordering, view): +        valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) + +        if valid_fields is None: +            # Default to allowing filtering on serializer fields +            serializer_class = getattr(view, 'serializer_class') +            if serializer_class is None: +                msg = ("Cannot use %s on a view which does not have either a " +                       "'serializer_class' or 'ordering_fields' attribute.") +                raise ImproperlyConfigured(msg % self.__class__.__name__) +            valid_fields = [ +                field.source or field_name +                for field_name, field in serializer_class().fields.items() +                if not getattr(field, 'write_only', False) +            ] +        elif valid_fields == '__all__': +            # View explictly allows filtering on any model field +            valid_fields = [field.name for field in queryset.model._meta.fields] +            valid_fields += queryset.query.aggregates.keys() + +        return [term for term in ordering if term.lstrip('-') in valid_fields]      def filter_queryset(self, request, queryset, view):          ordering = self.get_ordering(request)          if ordering:              # Skip any incorrect parameters -            ordering = self.remove_invalid_fields(queryset, ordering) +            ordering = self.remove_invalid_fields(queryset, ordering, view)          if not ordering:              # Use 'ordering' attribute by default diff --git a/rest_framework/generics.py b/rest_framework/generics.py index bd33c01a..c3256844 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -352,7 +352,7 @@ class GenericAPIView(views.APIView):      def post_delete(self, obj):          """ -        Placeholder method for calling after saving an object. +        Placeholder method for calling after deleting an object.          """          pass diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index b62a4cc1..2cc87eef 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -11,6 +11,7 @@ from django.http import Http404  from rest_framework import status  from rest_framework.response import Response  from rest_framework.request import clone_request +from rest_framework.settings import api_settings  import warnings @@ -60,7 +61,7 @@ class CreateModelMixin(object):      def get_success_headers(self, data):          try: -            return {'Location': data['url']} +            return {'Location': data[api_settings.URL_FIELD_NAME]}          except (TypeError, KeyError):              return {} @@ -115,30 +116,27 @@ class UpdateModelMixin(object):          partial = kwargs.pop('partial', False)          self.object = self.get_object_or_none() -        if self.object is None: -            created = True -            save_kwargs = {'force_insert': True} -            success_status_code = status.HTTP_201_CREATED -        else: -            created = False -            save_kwargs = {'force_update': True} -            success_status_code = status.HTTP_200_OK -          serializer = self.get_serializer(self.object, data=request.DATA,                                           files=request.FILES, partial=partial) -        if serializer.is_valid(): -            try: -                self.pre_save(serializer.object) -            except ValidationError as err: -                # full_clean on model instance may be called in pre_save, so we -                # have to handle eventual errors. -                return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) -            self.object = serializer.save(**save_kwargs) -            self.post_save(self.object, created=created) -            return Response(serializer.data, status=success_status_code) +        if not serializer.is_valid(): +            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) +        try: +            self.pre_save(serializer.object) +        except ValidationError as err: +            # full_clean on model instance may be called in pre_save, +            # so we have to handle eventual errors. +            return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST) + +        if self.object is None: +            self.object = serializer.save(force_insert=True) +            self.post_save(self.object, created=True) +            return Response(serializer.data, status=status.HTTP_201_CREATED) + +        self.object = serializer.save(force_update=True) +        self.post_save(self.object, created=False) +        return Response(serializer.data, status=status.HTTP_200_OK)      def partial_update(self, request, *args, **kwargs):          kwargs['partial'] = True diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4785c009..3b234dd5 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -33,6 +33,7 @@ class RelatedField(WritableField):      many_widget = widgets.SelectMultiple      form_field_class = forms.ChoiceField      many_form_field_class = forms.MultipleChoiceField +    null_values = (None, '', 'None')      cache_choices = False      empty_label = None @@ -50,6 +51,8 @@ class RelatedField(WritableField):          super(RelatedField, self).__init__(*args, **kwargs)          if not self.required: +            # Accessed in ModelChoiceIterator django/forms/models.py:1034 +            # If set adds empty choice.              self.empty_label = BLANK_CHOICE_DASH[0][1]          self.queryset = queryset @@ -57,16 +60,11 @@ class RelatedField(WritableField):      def initialize(self, parent, field_name):          super(RelatedField, self).initialize(parent, field_name)          if self.queryset is None and not self.read_only: -            try: -                manager = getattr(self.parent.opts.model, self.source or field_name) -                if hasattr(manager, 'related'):  # Forward -                    self.queryset = manager.related.model._default_manager.all() -                else:  # Reverse -                    self.queryset = manager.field.rel.to._default_manager.all() -            except Exception: -                msg = ('Serializer related fields must include a `queryset`' + -                       ' argument or set `read_only=True') -                raise Exception(msg) +            manager = getattr(self.parent.opts.model, self.source or field_name) +            if hasattr(manager, 'related'):  # Forward +                self.queryset = manager.related.model._default_manager.all() +            else:  # Reverse +                self.queryset = manager.field.rel.to._default_manager.all()      ### We need this stuff to make form choices work... @@ -115,6 +113,14 @@ class RelatedField(WritableField):      choices = property(_get_choices, _set_choices) +    ### Default value handling + +    def get_default_value(self): +        default = super(RelatedField, self).get_default_value() +        if self.many and default is None: +            return [] +        return default +      ### Regular serializer stuff...      def field_to_native(self, obj, field_name): @@ -163,11 +169,11 @@ class RelatedField(WritableField):          except KeyError:              if self.partial:                  return -            value = [] if self.many else None +            value = self.get_default_value() -        if value in (None, '') and self.required: -            raise ValidationError(self.error_messages['required']) -        elif value in (None, ''): +        if value in self.null_values: +            if self.required: +                raise ValidationError(self.error_messages['required'])              into[(self.source or field_name)] = None          elif self.many:              into[(self.source or field_name)] = [self.from_native(item) for item in value] diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 2fdd3337..7a7da561 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -10,6 +10,7 @@ from __future__ import unicode_literals  import copy  import json +import django  from django import forms  from django.core.exceptions import ImproperlyConfigured  from django.http.multipartparser import parse_header @@ -145,7 +146,7 @@ class XMLRenderer(BaseRenderer):      def render(self, data, accepted_media_type=None, renderer_context=None):          """ -        Renders *obj* into serialized XML. +        Renders `data` into serialized XML.          """          if data is None:              return '' @@ -195,7 +196,7 @@ class YAMLRenderer(BaseRenderer):      def render(self, data, accepted_media_type=None, renderer_context=None):          """ -        Renders *obj* into serialized YAML. +        Renders `data` into serialized YAML.          """          assert yaml, 'YAMLRenderer requires pyyaml to be installed' @@ -426,7 +427,7 @@ class BrowsableAPIRenderer(BaseRenderer):                  files = request.FILES              except ParseError:                  data = None -                files = None         +                files = None          else:              data = None              files = None @@ -543,6 +544,14 @@ class BrowsableAPIRenderer(BaseRenderer):          raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request)          raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form +        response_headers = dict(response.items()) +        renderer_content_type = '' +        if renderer: +            renderer_content_type = '%s' % renderer.media_type +            if renderer.charset: +                renderer_content_type += ' ;%s' % renderer.charset +        response_headers['Content-Type'] = renderer_content_type +          context = {              'content': self.get_content(renderer, data, accepted_media_type, renderer_context),              'view': view, @@ -554,6 +563,7 @@ class BrowsableAPIRenderer(BaseRenderer):              'breadcrumblist': self.get_breadcrumbs(request),              'allowed_methods': view.allowed_methods,              'available_formats': [renderer.format for renderer in view.renderer_classes], +            'response_headers': response_headers,              'put_form': self.get_rendered_html_form(view, 'PUT', request),              'post_form': self.get_rendered_html_form(view, 'POST', request), @@ -597,7 +607,7 @@ class MultiPartRenderer(BaseRenderer):      media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'      format = 'multipart'      charset = 'utf-8' -    BOUNDARY = 'BoUnDaRyStRiNg' +    BOUNDARY = 'BoUnDaRyStRiNg' if django.VERSION >= (1, 5) else b'BoUnDaRyStRiNg'      def render(self, data, accepted_media_type=None, renderer_context=None):          return encode_multipart(self.BOUNDARY, data) diff --git a/rest_framework/request.py b/rest_framework/request.py index fcea2508..40467c03 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -223,7 +223,7 @@ class Request(object):      def user(self, value):          """          Sets the user on the current request. This is necessary to maintain -        compatilbility with django.contrib.auth where the user proprety is +        compatibility with django.contrib.auth where the user property is          set in the login and logout functions.          """          self._user = value @@ -279,10 +279,9 @@ class Request(object):          if not _hasattr(self, '_method'):              self._method = self._request.method -            if self._method == 'POST': -                # Allow X-HTTP-METHOD-OVERRIDE header -                self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', -                                             self._method) +            # Allow X-HTTP-METHOD-OVERRIDE header +            self._method = self.META.get('HTTP_X_HTTP_METHOD_OVERRIDE', +                                         self._method)      def _load_stream(self):          """ @@ -347,7 +346,7 @@ class Request(object):          media_type = self.content_type          if stream is None or media_type is None: -            empty_data = QueryDict('', self._request._encoding) +            empty_data = QueryDict('', encoding=self._request._encoding)              empty_files = MultiValueDict()              return (empty_data, empty_files) @@ -363,7 +362,7 @@ class Request(object):              # re-raise.  Ensures we don't simply repeat the error when              # attempting to render the browsable renderer response, or when              # logging the request or similar. -            self._data = QueryDict('', self._request._encoding) +            self._data = QueryDict('', encoding=self._request._encoding)              self._files = MultiValueDict()              raise diff --git a/rest_framework/runtests/runtests.py b/rest_framework/runtests/runtests.py index da36d23f..2daaae4e 100755 --- a/rest_framework/runtests/runtests.py +++ b/rest_framework/runtests/runtests.py @@ -26,6 +26,10 @@ def usage():  def main(): +    try: +        django.setup() +    except AttributeError: +        pass      TestRunner = get_runner(settings)      test_runner = TestRunner() diff --git a/rest_framework/runtests/settings.py b/rest_framework/runtests/settings.py index 12aa73e7..36283d8e 100644 --- a/rest_framework/runtests/settings.py +++ b/rest_framework/runtests/settings.py @@ -97,6 +97,9 @@ INSTALLED_APPS = (      'rest_framework',      'rest_framework.authtoken',      'rest_framework.tests', +    'rest_framework.tests.accounts', +    'rest_framework.tests.records', +    'rest_framework.tests.users',  )  # OAuth is optional and won't work if there is no oauth_provider & oauth2 diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index cbf73fc3..5b14e403 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -13,12 +13,15 @@ response content is handled by parsers and renderers.  from __future__ import unicode_literals  import copy  import datetime +import inspect  import types  from decimal import Decimal  from django.db import models  from django.forms import widgets  from django.utils.datastructures import SortedDict -from rest_framework.compat import six +from rest_framework.compat import get_concrete_model, six +from rest_framework.settings import api_settings +  # Note: We do the following so that users of the framework can use this style:  # @@ -31,6 +34,27 @@ from rest_framework.relations import *  from rest_framework.fields import * +def _resolve_model(obj): +    """ +    Resolve supplied `obj` to a Django model class. + +    `obj` must be a Django model class itself, or a string +    representation of one.  Useful in situtations like GH #1225 where +    Django may not have resolved a string-based reference to a model in +    another model's foreign key definition. + +    String representations should have the format: +        'appname.ModelName' +    """ +    if type(obj) == str and len(obj.split('.')) == 2: +        app_name, model_name = obj.split('.') +        return models.get_model(app_name, model_name) +    elif inspect.isclass(obj) and issubclass(obj, models.Model): +        return obj +    else: +        raise ValueError("{0} is not a Django model".format(obj)) + +  def pretty_name(name):      """Converts 'first_name' to 'First name'"""      if not name: @@ -325,12 +349,13 @@ class BaseSerializer(WritableField):              method = getattr(self, 'transform_%s' % field_name, None)              if callable(method):                  value = method(obj, value) -            ret[key] = value +            if not getattr(field, 'write_only', False): +                ret[key] = value              ret.fields[key] = self.augment_field(field, field_name, key, value)          return ret -    def from_native(self, data, files): +    def from_native(self, data, files=None):          """          Deserialize primitives -> objects.          """ @@ -360,6 +385,9 @@ class BaseSerializer(WritableField):          Override default so that the serializer can be used as a nested field          across relationships.          """ +        if self.write_only: +            return None +          if self.source == '*':              return self.to_native(obj) @@ -404,16 +432,6 @@ class BaseSerializer(WritableField):                      raise ValidationError(self.error_messages['required'])                  return -        # Set the serializer object if it exists -        obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - -        # If we have a model manager or similar object then we need -        # to iterate through each instance. -        if (self.many and -            not hasattr(obj, '__iter__') and -            is_simple_callable(getattr(obj, 'all', None))): -            obj = obj.all() -          if self.source == '*':              if value:                  reverted_data = self.restore_fields(value, {}) @@ -423,6 +441,16 @@ class BaseSerializer(WritableField):              if value in (None, ''):                  into[(self.source or field_name)] = None              else: +                # Set the serializer object if it exists +                obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None + +                # If we have a model manager or similar object then we need +                # to iterate through each instance. +                if (self.many and +                    not hasattr(obj, '__iter__') and +                    is_simple_callable(getattr(obj, 'all', None))): +                    obj = obj.all() +                  kwargs = {                      'instance': obj,                      'data': value, @@ -467,7 +495,7 @@ class BaseSerializer(WritableField):              else:                  many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))                  if many: -                    warnings.warn('Implict list/queryset serialization is deprecated. ' +                    warnings.warn('Implicit list/queryset serialization is deprecated. '                                    'Use the `many=True` flag when instantiating the serializer.',                                    DeprecationWarning, stacklevel=3) @@ -524,7 +552,16 @@ class BaseSerializer(WritableField):          if self._data is None:              obj = self.object -            if self.many: +            if self.many is not None: +                many = self.many +            else: +                many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict)) +                if many: +                    warnings.warn('Implicit list/queryset serialization is deprecated. ' +                                  'Use the `many=True` flag when instantiating the serializer.', +                                  DeprecationWarning, stacklevel=2) + +            if many:                  self._data = [self.to_native(item) for item in obj]              else:                  self._data = self.to_native(obj) @@ -578,6 +615,7 @@ class ModelSerializerOptions(SerializerOptions):          super(ModelSerializerOptions, self).__init__(meta)          self.model = getattr(meta, 'model', None)          self.read_only_fields = getattr(meta, 'read_only_fields', ()) +        self.write_only_fields = getattr(meta, 'write_only_fields', ())  class ModelSerializer(Serializer): @@ -641,7 +679,7 @@ class ModelSerializer(Serializer):              if model_field.rel:                  to_many = isinstance(model_field,                                       models.fields.related.ManyToManyField) -                related_model = model_field.rel.to +                related_model = _resolve_model(model_field.rel.to)                  if to_many and not model_field.rel.through._meta.auto_created:                      has_through_model = True @@ -713,20 +751,38 @@ class ModelSerializer(Serializer):                      field.read_only = True                  ret[accessor_name] = field +         +        # Ensure that 'read_only_fields' is an iterable +        assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'  -        # Add the `read_only` flag to any fields that have bee specified +        # Add the `read_only` flag to any fields that have been specified          # in the `read_only_fields` option          for field_name in self.opts.read_only_fields: -            assert field_name not in self.base_fields.keys(), \ -                "field '%s' on serializer '%s' specified in " \ -                "`read_only_fields`, but also added " \ -                "as an explicit field.  Remove it from `read_only_fields`." % \ -                (field_name, self.__class__.__name__) -            assert field_name in ret, \ -                "Non-existant field '%s' specified in `read_only_fields` " \ -                "on serializer '%s'." % \ -                (field_name, self.__class__.__name__) +            assert field_name not in self.base_fields.keys(), ( +                "field '%s' on serializer '%s' specified in " +                "`read_only_fields`, but also added " +                "as an explicit field.  Remove it from `read_only_fields`." % +                (field_name, self.__class__.__name__)) +            assert field_name in ret, ( +                "Non-existant field '%s' specified in `read_only_fields` " +                "on serializer '%s'." % +                (field_name, self.__class__.__name__))              ret[field_name].read_only = True +         +        # Ensure that 'write_only_fields' is an iterable +        assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'  +         +        for field_name in self.opts.write_only_fields: +            assert field_name not in self.base_fields.keys(), ( +                "field '%s' on serializer '%s' specified in " +                "`write_only_fields`, but also added " +                "as an explicit field.  Remove it from `write_only_fields`." % +                (field_name, self.__class__.__name__)) +            assert field_name in ret, ( +                "Non-existant field '%s' specified in `write_only_fields` " +                "on serializer '%s'." % +                (field_name, self.__class__.__name__)) +            ret[field_name].write_only = True                      return ret @@ -829,7 +885,7 @@ class ModelSerializer(Serializer):          except KeyError:              return ModelField(model_field=model_field, **kwargs) -    def get_validation_exclusions(self): +    def get_validation_exclusions(self, instance=None):          """          Return a list of field names to exclude from model validation.          """ @@ -841,6 +897,7 @@ class ModelSerializer(Serializer):              field_name = field.source or field_name              if field_name in exclusions \                  and not field.read_only \ +                and (field.required or hasattr(instance, field_name)) \                  and not isinstance(field, Serializer):                  exclusions.remove(field_name)          return exclusions @@ -855,7 +912,7 @@ class ModelSerializer(Serializer):          the full_clean validation checking.          """          try: -            instance.full_clean(exclude=self.get_validation_exclusions()) +            instance.full_clean(exclude=self.get_validation_exclusions(instance))          except ValidationError as err:              self._errors = err.message_dict              return None @@ -883,7 +940,7 @@ class ModelSerializer(Serializer):                  m2m_data[field_name] = attrs.pop(field_name)          # Forward m2m relations -        for field in meta.many_to_many: +        for field in meta.many_to_many + meta.virtual_fields:              if field.name in attrs:                  m2m_data[field.name] = attrs.pop(field.name) @@ -979,6 +1036,7 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):          super(HyperlinkedModelSerializerOptions, self).__init__(meta)          self.view_name = getattr(meta, 'view_name', None)          self.lookup_field = getattr(meta, 'lookup_field', None) +        self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)  class HyperlinkedModelSerializer(ModelSerializer): @@ -997,13 +1055,13 @@ class HyperlinkedModelSerializer(ModelSerializer):          if self.opts.view_name is None:              self.opts.view_name = self._get_default_view_name(self.opts.model) -        if 'url' not in fields: +        if self.opts.url_field_name not in fields:              url_field = self._hyperlink_identify_field_class(                  view_name=self.opts.view_name,                  lookup_field=self.opts.lookup_field              )              ret = self._dict_class() -            ret['url'] = url_field +            ret[self.opts.url_field_name] = url_field              ret.update(fields)              fields = ret @@ -1039,7 +1097,7 @@ class HyperlinkedModelSerializer(ModelSerializer):          We need to override the default, to use the url as the identity.          """          try: -            return data.get('url', None) +            return data.get(self.opts.url_field_name, None)          except AttributeError:              return None diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 383de72e..189131f1 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -70,6 +70,10 @@ DEFAULTS = {      'PAGINATE_BY_PARAM': None,      'MAX_PAGINATE_BY': None, +    # Filtering +    'SEARCH_PARAM': 'search', +    'ORDERING_PARAM': 'ordering', +      # Authentication      'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',      'UNAUTHENTICATED_TOKEN': None, @@ -96,6 +100,7 @@ DEFAULTS = {      'URL_FORMAT_OVERRIDE': 'format',      'FORMAT_SUFFIX_KWARG': 'format', +    'URL_FIELD_NAME': 'url',      # Input and output formats      'DATE_INPUT_FORMATS': ( diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 42ede968..210741ed 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -34,7 +34,7 @@          <div class="navbar-inner">              <div class="container-fluid">                  <span href="/"> -                    {% block branding %}<a class='brand' href='http://django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %} +                    {% block branding %}<a class='brand' rel="nofollow" href='http://www.django-rest-framework.org'>Django REST framework <span class="version">{{ version }}</span></a>{% endblock %}                  </span>                  <ul class="nav pull-right">                      {% block userlinks %} @@ -119,7 +119,7 @@              </div>              <div class="response-info">                  <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %} -{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span> +{% for key, val in response_headers.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span>  {% endfor %}  </div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %}              </div> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 55f36149..a0f9c841 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -2,10 +2,12 @@ from __future__ import unicode_literals, absolute_import  from django import template  from django.core.urlresolvers import reverse, NoReverseMatch  from django.http import QueryDict -from django.utils.html import escape, smart_urlquote +from django.utils.encoding import iri_to_uri +from django.utils.html import escape  from django.utils.safestring import SafeData, mark_safe  from rest_framework.compat import urlparse, force_text, six -import re, string +from django.utils.html import smart_urlquote +import re  register = template.Library() @@ -61,7 +63,9 @@ def add_query_param(request, key, val):      """      Add a query parameter to the current request url, and return the new url.      """ -    return replace_query_param(request.get_full_path(), key, val) +    iri = request.get_full_path() +    uri = iri_to_uri(iri) +    return replace_query_param(uri, key, val)  @register.filter @@ -103,6 +107,17 @@ simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net  simple_email_re = re.compile(r'^\S+@\S+\.\S+$') +def smart_urlquote_wrapper(matched_url): +    """ +    Simple wrapper for smart_urlquote. ValueError("Invalid IPv6 URL") can +    be raised here, see issue #1386 +    """ +    try: +        return smart_urlquote(matched_url) +    except ValueError: +        return None + +  @register.filter  def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):      """ @@ -125,7 +140,6 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru      safe_input = isinstance(text, SafeData)      words = word_split_re.split(force_text(text))      for i, word in enumerate(words): -        match = None          if '.' in word or '@' in word or ':' in word:              # Deal with punctuation.              lead, middle, trail = '', word, '' @@ -147,9 +161,9 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru              url = None              nofollow_attr = ' rel="nofollow"' if nofollow else ''              if simple_url_re.match(middle): -                url = smart_urlquote(middle) +                url = smart_urlquote_wrapper(middle)              elif simple_url_2_re.match(middle): -                url = smart_urlquote('http://%s' % middle) +                url = smart_urlquote_wrapper('http://%s' % middle)              elif not ':' in middle and simple_email_re.match(middle):                  local, domain = middle.rsplit('@', 1)                  try: diff --git a/rest_framework/test.py b/rest_framework/test.py index 234d10a4..df5a5b3b 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -8,6 +8,7 @@ from django.conf import settings  from django.test.client import Client as DjangoClient  from django.test.client import ClientHandler  from django.test import testcases +from django.utils.http import urlencode  from rest_framework.settings import api_settings  from rest_framework.compat import RequestFactory as DjangoRequestFactory  from rest_framework.compat import force_bytes_or_smart_bytes, six @@ -71,6 +72,17 @@ class APIRequestFactory(DjangoRequestFactory):          return ret, content_type +    def get(self, path, data=None, **extra): +        r = { +            'QUERY_STRING': urlencode(data or {}, doseq=True), +        } +        # Fix to support old behavior where you have the arguments in the url +        # See #1461 +        if not data and '?' in path: +            r['QUERY_STRING'] = path.split('?')[1] +        r.update(extra) +        return self.generic('GET', path, **r) +      def post(self, path, data=None, format=None, content_type=None, **extra):          data, content_type = self._encode_data(data, format, content_type)          return self.generic('POST', path, data, content_type, **extra) diff --git a/rest_framework/tests/accounts/__init__.py b/rest_framework/tests/accounts/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/accounts/__init__.py diff --git a/rest_framework/tests/accounts/models.py b/rest_framework/tests/accounts/models.py new file mode 100644 index 00000000..525e601b --- /dev/null +++ b/rest_framework/tests/accounts/models.py @@ -0,0 +1,8 @@ +from django.db import models + +from rest_framework.tests.users.models import User + + +class Account(models.Model): +    owner = models.ForeignKey(User, related_name='accounts_owned') +    admins = models.ManyToManyField(User, blank=True, null=True, related_name='accounts_administered') diff --git a/rest_framework/tests/accounts/serializers.py b/rest_framework/tests/accounts/serializers.py new file mode 100644 index 00000000..a27b9ca6 --- /dev/null +++ b/rest_framework/tests/accounts/serializers.py @@ -0,0 +1,11 @@ +from rest_framework import serializers + +from rest_framework.tests.accounts.models import Account +from rest_framework.tests.users.serializers import UserSerializer + + +class AccountSerializer(serializers.ModelSerializer): +    admins = UserSerializer(many=True) + +    class Meta: +        model = Account diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 32a726c0..6c8f2342 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -103,7 +103,7 @@ class BlogPostComment(RESTFrameworkModel):  class Album(RESTFrameworkModel):      title = models.CharField(max_length=100, unique=True) - +    ref = models.CharField(max_length=10, unique=True, null=True, blank=True)  class Photo(RESTFrameworkModel):      description = models.TextField() @@ -168,3 +168,10 @@ class NullableOneToOneSource(RESTFrameworkModel):  class BasicModelSerializer(serializers.ModelSerializer):      class Meta:          model = BasicModel + + +# Models to test filters +class FilterableItem(models.Model): +    text = models.CharField(max_length=100) +    decimal = models.DecimalField(max_digits=4, decimal_places=2) +    date = models.DateField() diff --git a/rest_framework/tests/records/__init__.py b/rest_framework/tests/records/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/records/__init__.py diff --git a/rest_framework/tests/records/models.py b/rest_framework/tests/records/models.py new file mode 100644 index 00000000..76954807 --- /dev/null +++ b/rest_framework/tests/records/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class Record(models.Model): +    account = models.ForeignKey('accounts.Account', blank=True, null=True) +    owner = models.ForeignKey('users.User', blank=True, null=True) diff --git a/rest_framework/tests/serializers.py b/rest_framework/tests/serializers.py new file mode 100644 index 00000000..cc943c7d --- /dev/null +++ b/rest_framework/tests/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from rest_framework.tests.models import NullableForeignKeySource + + +class NullableFKSourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableForeignKeySource diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index fb0bc694..6c14debb 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -4,6 +4,7 @@ from django.contrib.auth.models import User  from django.http import HttpResponse  from django.test import TestCase  from django.utils import unittest +from django.utils.http import urlencode  from rest_framework import HTTP_HEADER_ENCODING  from rest_framework import exceptions  from rest_framework import permissions @@ -19,7 +20,7 @@ from rest_framework.authentication import (      OAuth2Authentication  )  from rest_framework.authtoken.models import Token -from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope +from rest_framework.compat import oauth2_provider, oauth2_provider_scope  from rest_framework.compat import oauth, oauth_provider  from rest_framework.test import APIRequestFactory, APIClient  from rest_framework.views import APIView @@ -53,10 +54,14 @@ urlpatterns = patterns('',          permission_classes=[permissions.TokenHasReadWriteScope]))  ) +class OAuth2AuthenticationDebug(OAuth2Authentication): +    allow_query_params_token = True +  if oauth2_provider is not None:      urlpatterns += patterns('',          url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),          url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), +        url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),          url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],              permission_classes=[permissions.TokenHasReadWriteScope])),      ) @@ -488,7 +493,7 @@ class OAuth2Tests(TestCase):          self.ACCESS_TOKEN = "access_token"          self.REFRESH_TOKEN = "refresh_token" -        self.oauth2_client = oauth2_provider_models.Client.objects.create( +        self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(                  client_id=self.CLIENT_ID,                  client_secret=self.CLIENT_SECRET,                  redirect_uri='', @@ -497,12 +502,12 @@ class OAuth2Tests(TestCase):                  user=None,              ) -        self.access_token = oauth2_provider_models.AccessToken.objects.create( +        self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(                  token=self.ACCESS_TOKEN,                  client=self.oauth2_client,                  user=self.user,              ) -        self.refresh_token = oauth2_provider_models.RefreshToken.objects.create( +        self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(                  user=self.user,                  access_token=self.access_token,                  client=self.oauth2_client @@ -546,6 +551,27 @@ class OAuth2Tests(TestCase):          self.assertEqual(response.status_code, 200)      @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_post_form_passing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in form data succeed""" +        response = self.csrf_client.post('/oauth2-test/', +                data={'access_token': self.access_token.token}) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_passing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True""" +        query = urlencode({'access_token': self.access_token.token}) +        response = self.csrf_client.get('/oauth2-test-debug/?%s' % query) +        self.assertEqual(response.status_code, 200) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') +    def test_get_form_failing_auth_url_transport(self): +        """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False""" +        query = urlencode({'access_token': self.access_token.token}) +        response = self.csrf_client.get('/oauth2-test/?%s' % query) +        self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) + +    @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')      def test_post_form_passing_auth(self):          """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""          auth = self._create_authorization_header() diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 5c96bce9..e127feef 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -860,7 +860,9 @@ class SlugFieldTests(TestCase):  class URLFieldTests(TestCase):      """ -    Tests for URLField attribute values +    Tests for URLField attribute values. + +    (Includes test for #1210, checking that validators can be overridden.)      """      class URLFieldModel(RESTFrameworkModel): @@ -902,6 +904,11 @@ class URLFieldTests(TestCase):          self.assertEqual(getattr(serializer.fields['url_field'],                           'max_length'), 20) +    def test_validators_can_be_overridden(self): +        url_field = serializers.URLField(validators=[]) +        validators = url_field.validators +        self.assertEqual([], validators, 'Passing `validators` kwarg should have overridden default validators') +  class FieldMetadata(TestCase):      def setUp(self): diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index 8a03a077..2aa6f81a 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -1,25 +1,21 @@  from __future__ import unicode_literals  import datetime  from decimal import Decimal -from django.conf.urls import patterns, url  from django.db import models  from django.core.urlresolvers import reverse  from django.test import TestCase  from django.utils import unittest +from django.conf.urls import patterns, url  from rest_framework import generics, serializers, status, filters  from rest_framework.compat import django_filters  from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import BasicModel +from .models import FilterableItem +from .utils import temporary_setting  factory = APIRequestFactory() -class FilterableItem(models.Model): -    text = models.CharField(max_length=100) -    decimal = models.DecimalField(max_digits=4, decimal_places=2) -    date = models.DateField() - -  if django_filters:      # Basic filter on a list view.      class FilterFieldsRootView(generics.ListCreateAPIView): @@ -129,7 +125,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the decimal filter works.          search_decimal = Decimal('2.25') -        request = factory.get('/?decimal=%s' % search_decimal) +        request = factory.get('/', {'decimal': '%s' % search_decimal})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['decimal'] == search_decimal] @@ -137,7 +133,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the date filter works.          search_date = datetime.date(2012, 9, 22) -        request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-09-22' +        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-09-22'          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['date'] == search_date] @@ -152,7 +148,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the decimal filter works.          search_decimal = Decimal('2.25') -        request = factory.get('/?decimal=%s' % search_decimal) +        request = factory.get('/', {'decimal': '%s' % search_decimal})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['decimal'] == search_decimal] @@ -185,7 +181,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the decimal filter set with 'lt' in the filter class works.          search_decimal = Decimal('4.25') -        request = factory.get('/?decimal=%s' % search_decimal) +        request = factory.get('/', {'decimal': '%s' % search_decimal})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['decimal'] < search_decimal] @@ -193,7 +189,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the date filter set with 'gt' in the filter class works.          search_date = datetime.date(2012, 10, 2) -        request = factory.get('/?date=%s' % search_date)  # search_date str: '2012-10-02' +        request = factory.get('/', {'date': '%s' % search_date})  # search_date str: '2012-10-02'          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['date'] > search_date] @@ -201,7 +197,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that the text filter set with 'icontains' in the filter class works.          search_text = 'ff' -        request = factory.get('/?text=%s' % search_text) +        request = factory.get('/', {'text': '%s' % search_text})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if search_text in f['text'].lower()] @@ -210,7 +206,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          # Tests that multiple filters works.          search_decimal = Decimal('5.25')          search_date = datetime.date(2012, 10, 2) -        request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) +        request = factory.get('/', { +            'decimal': '%s' % (search_decimal,), +            'date': '%s' % (search_date,) +        })          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK)          expected_data = [f for f in self.data if f['date'] > search_date and @@ -235,7 +234,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):          view = FilterFieldsRootView.as_view()          search_integer = 10 -        request = factory.get('/?integer=%s' % search_integer) +        request = factory.get('/', {'integer': '%s' % search_integer})          response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -266,14 +265,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):          # Tests that the decimal filter set that should fail.          search_decimal = Decimal('4.25')          high_item = self.objects.filter(decimal__gt=search_decimal)[0] -        response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) +        response = self.client.get( +            '{url}'.format(url=self._get_url(high_item)), +            {'decimal': '{param}'.format(param=search_decimal)})          self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)          # Tests that the decimal filter set that should succeed.          search_decimal = Decimal('4.25')          low_item = self.objects.filter(decimal__lt=search_decimal)[0]          low_item_data = self._serialize_object(low_item) -        response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) +        response = self.client.get( +            '{url}'.format(url=self._get_url(low_item)), +            {'decimal': '{param}'.format(param=search_decimal)})          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertEqual(response.data, low_item_data) @@ -282,7 +285,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):          search_date = datetime.date(2012, 10, 2)          valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]          valid_item_data = self._serialize_object(valid_item) -        response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) +        response = self.client.get( +            '{url}'.format(url=self._get_url(valid_item)), { +                'decimal': '{decimal}'.format(decimal=search_decimal), +                'date': '{date}'.format(date=search_date) +            })          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertEqual(response.data, valid_item_data) @@ -316,7 +323,7 @@ class SearchFilterTests(TestCase):              search_fields = ('title', 'text')          view = SearchListView.as_view() -        request = factory.get('?search=b') +        request = factory.get('/', {'search': 'b'})          response = view(request)          self.assertEqual(              response.data, @@ -333,7 +340,7 @@ class SearchFilterTests(TestCase):              search_fields = ('=title', 'text')          view = SearchListView.as_view() -        request = factory.get('?search=zzz') +        request = factory.get('/', {'search': 'zzz'})          response = view(request)          self.assertEqual(              response.data, @@ -349,7 +356,7 @@ class SearchFilterTests(TestCase):              search_fields = ('title', '^text')          view = SearchListView.as_view() -        request = factory.get('?search=b') +        request = factory.get('/', {'search': 'b'})          response = view(request)          self.assertEqual(              response.data, @@ -358,6 +365,24 @@ class SearchFilterTests(TestCase):              ]          ) +    def test_search_with_nonstandard_search_param(self): +        with temporary_setting('SEARCH_PARAM', 'query', module=filters): +            class SearchListView(generics.ListAPIView): +                model = SearchFilterModel +                filter_backends = (filters.SearchFilter,) +                search_fields = ('title', 'text') + +            view = SearchListView.as_view() +            request = factory.get('/', {'query': 'b'}) +            response = view(request) +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, 'title': 'z', 'text': 'abc'}, +                    {'id': 2, 'title': 'zz', 'text': 'bcd'} +                ] +            ) +  class OrdringFilterModel(models.Model):      title = models.CharField(max_length=20) @@ -369,7 +394,6 @@ class OrderingFilterRelatedModel(models.Model):                                         related_name="relateds") -  class OrderingFilterTests(TestCase):      def setUp(self):          # Sequence of title/text is: @@ -395,9 +419,10 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = ('title',) +            ordering_fields = ('text',)          view = OrderingListView.as_view() -        request = factory.get('?ordering=text') +        request = factory.get('/', {'ordering': 'text'})          response = view(request)          self.assertEqual(              response.data, @@ -413,9 +438,10 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = ('title',) +            ordering_fields = ('text',)          view = OrderingListView.as_view() -        request = factory.get('?ordering=-text') +        request = factory.get('/', {'ordering': '-text'})          response = view(request)          self.assertEqual(              response.data, @@ -431,9 +457,10 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = ('title',) +            ordering_fields = ('text',)          view = OrderingListView.as_view() -        request = factory.get('?ordering=foobar') +        request = factory.get('/', {'ordering': 'foobar'})          response = view(request)          self.assertEqual(              response.data, @@ -449,6 +476,7 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = ('title',) +            oredering_fields = ('text',)          view = OrderingListView.as_view()          request = factory.get('') @@ -467,6 +495,7 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = 'title' +            ordering_fields = ('text',)          view = OrderingListView.as_view()          request = factory.get('') @@ -495,11 +524,12 @@ class OrderingFilterTests(TestCase):              model = OrdringFilterModel              filter_backends = (filters.OrderingFilter,)              ordering = 'title' +            ordering_fields = '__all__'              queryset = OrdringFilterModel.objects.all().annotate(                  models.Count("relateds"))          view = OrderingListView.as_view() -        request = factory.get('?ordering=relateds__count') +        request = factory.get('/', {'ordering': 'relateds__count'})          response = view(request)          self.assertEqual(              response.data, @@ -510,5 +540,122 @@ class OrderingFilterTests(TestCase):              ]          ) +    def test_ordering_with_nonstandard_ordering_param(self): +        with temporary_setting('ORDERING_PARAM', 'order', filters): +            class OrderingListView(generics.ListAPIView): +                model = OrdringFilterModel +                filter_backends = (filters.OrderingFilter,) +                ordering = ('title',) +                ordering_fields = ('text',) + +            view = OrderingListView.as_view() +            request = factory.get('/', {'order': 'text'}) +            response = view(request) +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, 'title': 'zyx', 'text': 'abc'}, +                    {'id': 2, 'title': 'yxw', 'text': 'bcd'}, +                    {'id': 3, 'title': 'xwv', 'text': 'cde'}, +                ] +            ) + + +class SensitiveOrderingFilterModel(models.Model): +    username = models.CharField(max_length=20) +    password = models.CharField(max_length=100) + + +# Three different styles of serializer. +# All should allow ordering by username, but not by password. +class SensitiveDataSerializer1(serializers.ModelSerializer): +    username = serializers.CharField() +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'username') + +class SensitiveDataSerializer2(serializers.ModelSerializer): +    username = serializers.CharField() +    password = serializers.CharField(write_only=True) + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'username', 'password') + + +class SensitiveDataSerializer3(serializers.ModelSerializer): +    user = serializers.CharField(source='username') + +    class Meta: +        model = SensitiveOrderingFilterModel +        fields = ('id', 'user') + + +class SensitiveOrderingFilterTests(TestCase): +    def setUp(self): +        for idx in range(3): +            username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] +            password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] +            SensitiveOrderingFilterModel(username=username, password=password).save() + +    def test_order_by_serializer_fields(self): +        for serializer_cls in [ +            SensitiveDataSerializer1, +            SensitiveDataSerializer2, +            SensitiveDataSerializer3 +        ]: +            class OrderingListView(generics.ListAPIView): +                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') +                filter_backends = (filters.OrderingFilter,) +                serializer_class = serializer_cls + +            view = OrderingListView.as_view() +            request = factory.get('/', {'ordering': '-username'}) +            response = view(request) + +            if serializer_cls == SensitiveDataSerializer3: +                username_field = 'user' +            else: +                username_field = 'username' + +            # Note: Inverse username ordering correctly applied. +            self.assertEqual( +                response.data, +                [ +                    {'id': 3, username_field: 'userC'}, +                    {'id': 2, username_field: 'userB'}, +                    {'id': 1, username_field: 'userA'}, +                ] +            ) + +    def test_cannot_order_by_non_serializer_fields(self): +        for serializer_cls in [ +            SensitiveDataSerializer1, +            SensitiveDataSerializer2, +            SensitiveDataSerializer3 +        ]: +            class OrderingListView(generics.ListAPIView): +                queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') +                filter_backends = (filters.OrderingFilter,) +                serializer_class = serializer_cls + +            view = OrderingListView.as_view() +            request = factory.get('/', {'ordering': 'password'}) +            response = view(request) + +            if serializer_cls == SensitiveDataSerializer3: +                username_field = 'user' +            else: +                username_field = 'username' + +            # Note: The passwords are not in order.  Default ordering is used. +            self.assertEqual( +                response.data, +                [ +                    {'id': 1, username_field: 'userA'}, # PassB +                    {'id': 2, username_field: 'userB'}, # PassC +                    {'id': 3, username_field: 'userC'}, # PassA +                ] +            ) diff --git a/rest_framework/tests/test_genericrelations.py b/rest_framework/tests/test_genericrelations.py index c38bfb9f..fa09c9e6 100644 --- a/rest_framework/tests/test_genericrelations.py +++ b/rest_framework/tests/test_genericrelations.py @@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK  from django.db import models  from django.test import TestCase  from rest_framework import serializers +from rest_framework.compat import python_2_unicode_compatible +@python_2_unicode_compatible  class Tag(models.Model):      """      Tags have a descriptive slug, and are attached to an arbitrary object. @@ -15,10 +17,11 @@ class Tag(models.Model):      object_id = models.PositiveIntegerField()      tagged_item = GenericForeignKey('content_type', 'object_id') -    def __unicode__(self): +    def __str__(self):          return self.tag +@python_2_unicode_compatible  class Bookmark(models.Model):      """      A URL bookmark that may have multiple tags attached. @@ -26,10 +29,11 @@ class Bookmark(models.Model):      url = models.URLField()      tags = GenericRelation(Tag) -    def __unicode__(self): +    def __str__(self):          return 'Bookmark: %s' % self.url +@python_2_unicode_compatible  class Note(models.Model):      """      A textual note that may have multiple tags attached. @@ -37,7 +41,7 @@ class Note(models.Model):      text = models.TextField()      tags = GenericRelation(Tag) -    def __unicode__(self): +    def __str__(self):          return 'Note: %s' % self.text @@ -69,6 +73,35 @@ class TestGenericRelations(TestCase):          }          self.assertEqual(serializer.data, expected) +    def test_generic_nested_relation(self): +        """ +        Test saving a GenericRelation field via a nested serializer. +        """ + +        class TagSerializer(serializers.ModelSerializer): +            class Meta: +                model = Tag +                exclude = ('content_type', 'object_id') + +        class BookmarkSerializer(serializers.ModelSerializer): +            tags = TagSerializer() + +            class Meta: +                model = Bookmark +                exclude = ('id',) + +        data = { +            'url': 'https://docs.djangoproject.com/', +            'tags': [ +                {'tag': 'contenttypes'}, +                {'tag': 'genericrelations'}, +            ] +        } +        serializer = BookmarkSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        serializer.save() +        self.assertEqual(serializer.object.tags.count(), 2) +      def test_generic_fk(self):          """          Test a relationship that spans a GenericForeignKey field. diff --git a/rest_framework/tests/test_htmlrenderer.py b/rest_framework/tests/test_htmlrenderer.py index 6c570dfd..1cbca04c 100644 --- a/rest_framework/tests/test_htmlrenderer.py +++ b/rest_framework/tests/test_htmlrenderer.py @@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase):          """          self.get_template = django.template.loader.get_template -        def get_template(template_name): +        def get_template(template_name, dirs=None):              if template_name == 'example.html':                  return Template("example: {{ object }}")              raise TemplateDoesNotExist(template_name) @@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase):      def test_not_found_html_view_with_template(self):          response = self.client.get('/not_found')          self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) -        self.assertEqual(response.content, six.b("404: Not found")) +        self.assertTrue(response.content in ( +            six.b("404: Not found"), six.b("404 Not Found")))          self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')      def test_permission_denied_html_view_with_template(self):          response = self.client.get('/permission_denied')          self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) -        self.assertEqual(response.content, six.b("403: Permission denied")) +        self.assertTrue(response.content in ( +            six.b("403: Permission denied"), six.b("403 Forbidden")))          self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index ea7f70f2..5fb1b47e 100644 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -1,8 +1,9 @@  from __future__ import unicode_literals  import json -from django.conf.urls import patterns, url  from django.test import TestCase  from rest_framework import generics, status, serializers +from django.conf.urls import patterns, url +from rest_framework.settings import api_settings  from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import (      Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, @@ -331,3 +332,48 @@ class TestOverriddenURLField(TestCase):              serializer.data,              {'title': 'New blog post', 'url': 'foo bar'}          ) + + +class TestURLFieldNameBySettings(TestCase): +    urls = 'rest_framework.tests.test_hyperlinkedserializers' + +    def setUp(self): +        self.saved_url_field_name = api_settings.URL_FIELD_NAME +        api_settings.URL_FIELD_NAME = 'global_url_field' + +        class Serializer(serializers.HyperlinkedModelSerializer): + +            class Meta: +                model = BlogPost +                fields = ('title', api_settings.URL_FIELD_NAME) + +        self.Serializer = Serializer +        self.obj = BlogPost.objects.create(title="New blog post") + +    def tearDown(self): +        api_settings.URL_FIELD_NAME = self.saved_url_field_name + +    def test_overridden_url_field_name(self): +        request = factory.get('/posts/') +        serializer = self.Serializer(self.obj, context={'request': request}) +        self.assertIn(api_settings.URL_FIELD_NAME, serializer.data) + + +class TestURLFieldNameByOptions(TestCase): +    urls = 'rest_framework.tests.test_hyperlinkedserializers' + +    def setUp(self): +        class Serializer(serializers.HyperlinkedModelSerializer): + +            class Meta: +                model = BlogPost +                fields = ('title', 'serializer_url_field') +                url_field_name = 'serializer_url_field' + +        self.Serializer = Serializer +        self.obj = BlogPost.objects.create(title="New blog post") + +    def test_overridden_url_field_name(self): +        request = factory.get('/posts/') +        serializer = self.Serializer(self.obj, context={'request': request}) +        self.assertIn(self.Serializer.Meta.url_field_name, serializer.data) diff --git a/rest_framework/tests/test_nullable_fields.py b/rest_framework/tests/test_nullable_fields.py new file mode 100644 index 00000000..4812530e --- /dev/null +++ b/rest_framework/tests/test_nullable_fields.py @@ -0,0 +1,30 @@ +from django.core.urlresolvers import reverse + +from django.conf.urls import patterns, url +from rest_framework.test import APITestCase +from rest_framework.tests.models import NullableForeignKeySource +from rest_framework.tests.serializers import NullableFKSourceSerializer +from rest_framework.tests.views import NullableFKSourceDetail + + +urlpatterns = patterns( +    '', +    url(r'^objects/(?P<pk>\d+)/$', NullableFKSourceDetail.as_view(), name='object-detail'), +) + + +class NullableForeignKeyTests(APITestCase): +    """ +    DRF should be able to handle nullable foreign keys when a test +    Client POST/PUT request is made with its own serialized object. +    """ +    urls = 'rest_framework.tests.test_nullable_fields' + +    def test_updating_object_with_null_fk(self): +        obj = NullableForeignKeySource(name='example', target=None) +        obj.save() +        serialized_data = NullableFKSourceSerializer(obj).data + +        response = self.client.put(reverse('object-detail', args=[obj.pk]), serialized_data) + +        self.assertEqual(response.data, serialized_data) diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index cadb515f..24c1ba39 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -9,14 +9,18 @@ from rest_framework import generics, status, pagination, filters, serializers  from rest_framework.compat import django_filters  from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import BasicModel +from .models import FilterableItem  factory = APIRequestFactory() +# Helper function to split arguments out of an url +def split_arguments_from_url(url): +    if '?' not in url: +        return url -class FilterableItem(models.Model): -    text = models.CharField(max_length=100) -    decimal = models.DecimalField(max_digits=4, decimal_places=2) -    date = models.DateField() +    path, args = url.split('?') +    args = dict(r.split('=') for r in args.split('&')) +    return path, args  class RootView(generics.ListCreateAPIView): @@ -84,7 +88,7 @@ class IntegrationTestPagination(TestCase):          self.assertNotEqual(response.data['next'], None)          self.assertEqual(response.data['previous'], None) -        request = factory.get(response.data['next']) +        request = factory.get(*split_arguments_from_url(response.data['next']))          with self.assertNumQueries(2):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -93,7 +97,7 @@ class IntegrationTestPagination(TestCase):          self.assertNotEqual(response.data['next'], None)          self.assertNotEqual(response.data['previous'], None) -        request = factory.get(response.data['next']) +        request = factory.get(*split_arguments_from_url(response.data['next']))          with self.assertNumQueries(2):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -146,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          EXPECTED_NUM_QUERIES = 2 -        request = factory.get('/?decimal=15.20') +        request = factory.get('/', {'decimal': '15.20'})          with self.assertNumQueries(EXPECTED_NUM_QUERIES):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -155,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.assertNotEqual(response.data['next'], None)          self.assertEqual(response.data['previous'], None) -        request = factory.get(response.data['next']) +        request = factory.get(*split_arguments_from_url(response.data['next']))          with self.assertNumQueries(EXPECTED_NUM_QUERIES):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -164,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.assertEqual(response.data['next'], None)          self.assertNotEqual(response.data['previous'], None) -        request = factory.get(response.data['previous']) +        request = factory.get(*split_arguments_from_url(response.data['previous']))          with self.assertNumQueries(EXPECTED_NUM_QUERIES):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -191,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          view = BasicFilterFieldsRootView.as_view() -        request = factory.get('/?decimal=15.20') +        request = factory.get('/', {'decimal': '15.20'})          with self.assertNumQueries(2):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -200,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.assertNotEqual(response.data['next'], None)          self.assertEqual(response.data['previous'], None) -        request = factory.get(response.data['next']) +        request = factory.get(*split_arguments_from_url(response.data['next']))          with self.assertNumQueries(2):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -209,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):          self.assertEqual(response.data['next'], None)          self.assertNotEqual(response.data['previous'], None) -        request = factory.get(response.data['previous']) +        request = factory.get(*split_arguments_from_url(response.data['previous']))          with self.assertNumQueries(2):              response = view(request).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -317,7 +321,7 @@ class TestCustomPaginateByParam(TestCase):          """          If paginate_by_param is set, the new kwarg should limit per view requests.          """ -        request = factory.get('/?page_size=5') +        request = factory.get('/', {'page_size': 5})          response = self.view(request).render()          self.assertEqual(response.data['count'], 13)          self.assertEqual(response.data['results'], self.data[:5]) @@ -345,7 +349,7 @@ class TestMaxPaginateByParam(TestCase):          """          If max_paginate_by is set, it should limit page size for the view.          """ -        request = factory.get('/?page_size=10') +        request = factory.get('/', data={'page_size': 10})          response = self.view(request).render()          self.assertEqual(response.data['count'], 13)          self.assertEqual(response.data['results'], self.data[:5]) diff --git a/rest_framework/tests/test_relations.py b/rest_framework/tests/test_relations.py index d19219c9..37ac826b 100644 --- a/rest_framework/tests/test_relations.py +++ b/rest_framework/tests/test_relations.py @@ -2,8 +2,10 @@  General tests for relational fields.  """  from __future__ import unicode_literals +from django import get_version  from django.db import models  from django.test import TestCase +from django.utils import unittest  from rest_framework import serializers  from rest_framework.tests.models import BlogPost @@ -98,3 +100,45 @@ class RelatedFieldSourceTests(TestCase):          obj = ClassWithQuerysetMethod()          value = field.field_to_native(obj, 'field_name')          self.assertEqual(value, ['BlogPost object']) + +    # Regression for #1129 +    def test_exception_for_incorect_fk(self): +        """ +        Check that the exception message are correct if the source field +        doesn't exist. +        """ +        from rest_framework.tests.models import ManyToManySource +        class Meta: +            model = ManyToManySource +        attrs = { +            'name': serializers.SlugRelatedField( +                slug_field='name', source='banzai'), +            'Meta': Meta, +        } + +        TestSerializer = type(str('TestSerializer'), +            (serializers.ModelSerializer,), attrs) +        with self.assertRaises(AttributeError): +            TestSerializer(data={'name': 'foo'}) + +@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') +class RelatedFieldChoicesTests(TestCase): +    """ +    Tests for #1408 "Web browseable API doesn't have blank option on drop down list box" +    https://github.com/tomchristie/django-rest-framework/issues/1408 +    """ +    def test_blank_option_is_added_to_choice_if_required_equals_false(self): +        """ + +        """ +        post = BlogPost(title="Checking blank option is added") +        post.save() + +        queryset = BlogPost.objects.all() +        field = serializers.RelatedField(required=False, queryset=queryset) + +        choice_count = BlogPost.objects.count() +        widget_count = len(field.widget.choices) + +        self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') + diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index d393b0c3..4d9da489 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -3,9 +3,7 @@ from django.db import models  from django.test import TestCase  from rest_framework import serializers - -class OneToOneTarget(models.Model): -    name = models.CharField(max_length=100) +from .models import OneToOneTarget  class OneToOneSource(models.Model): diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index 9cb68233..460c02a9 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals  from decimal import Decimal  from django.conf.urls import patterns, url, include  from django.core.cache import cache +from django.db import models  from django.test import TestCase  from django.utils import unittest  from django.utils.translation import ugettext_lazy as _ @@ -35,6 +36,10 @@ expected_results = [  ] +class DummyTestModel(models.Model): +    name = models.CharField(max_length=42, default='') + +  class BasicRendererTests(TestCase):      def test_expected_results(self):          for value, renderer_cls, expected in expected_results: @@ -252,6 +257,18 @@ class RendererEndToEndTests(TestCase):          self.assertEqual(resp.get('Content-Type', None), None)          self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT) +    def test_contains_headers_of_api_response(self): +        """ +        Issue #1437 + +        Test we display the headers of the API response and not those from the +        HTML response +        """ +        resp = self.client.get('/html1') +        self.assertContains(resp, '>GET, HEAD, OPTIONS<') +        self.assertContains(resp, '>application/json<') +        self.assertNotContains(resp, '>text/html; charset=utf-8<') +  _flat_repr = '{"foo": ["bar", "baz"]}'  _indented_repr = '{\n  "foo": [\n    "bar",\n    "baz"\n  ]\n}' @@ -277,6 +294,20 @@ class JSONRendererTests(TestCase):          ret = JSONRenderer().render(_('test'))          self.assertEqual(ret, b'"test"') +    def test_render_queryset_values(self): +        o = DummyTestModel.objects.create(name='dummy') +        qs = DummyTestModel.objects.values('id', 'name') +        ret = JSONRenderer().render(qs) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, [{'id': o.id, 'name': o.name}]) + +    def test_render_queryset_values_list(self): +        o = DummyTestModel.objects.create(name='dummy') +        qs = DummyTestModel.objects.values_list('id', 'name') +        ret = JSONRenderer().render(qs) +        data = json.loads(ret.decode('utf-8')) +        self.assertEquals(data, [[o.id, o.name]]) +      def test_render_dict_abc_obj(self):          class Dict(MutableMapping):              def __init__(self): @@ -583,6 +614,10 @@ class CacheRenderTest(TestCase):          method = getattr(self.client, http_method)          resp = method(url)          del resp.client, resp.request +        try: +            del resp.wsgi_request +        except AttributeError: +            pass          return resp      def test_obj_pickling(self): diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index f07c31a3..e0da5fd4 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -68,6 +68,9 @@ class TestMethodOverloading(TestCase):          request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))          self.assertEqual(request.method, 'DELETE') +        request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE')) +        self.assertEqual(request.method, 'DELETE') +  class TestContentParsing(TestCase):      def test_standard_behaviour_determines_no_content_GET(self): diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 6d9b85ee..a09bf6f5 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals  from django.db import models  from django.db.models.fields import BLANK_CHOICE_DASH  from django.test import TestCase +from django.utils import unittest  from django.utils.datastructures import MultiValueDict  from django.utils.translation import ugettext_lazy as _  from rest_framework import serializers, fields, relations @@ -12,6 +13,31 @@ from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, Acti  from rest_framework.tests.models import BasicModelSerializer  import datetime  import pickle +try: +    import PIL +except: +    PIL = None + + +if PIL is not None: +    class AMOAFModel(RESTFrameworkModel): +        char_field = models.CharField(max_length=1024, blank=True) +        comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) +        decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) +        email_field = models.EmailField(max_length=1024, blank=True) +        file_field = models.FileField(upload_to='test', max_length=1024, blank=True) +        image_field = models.ImageField(upload_to='test', max_length=1024, blank=True) +        slug_field = models.SlugField(max_length=1024, blank=True) +        url_field = models.URLField(max_length=1024, blank=True) + +    class DVOAFModel(RESTFrameworkModel): +        positive_integer_field = models.PositiveIntegerField(blank=True) +        positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) +        email_field = models.EmailField(blank=True) +        file_field = models.FileField(upload_to='test', blank=True) +        image_field = models.ImageField(upload_to='test', blank=True) +        slug_field = models.SlugField(blank=True) +        url_field = models.URLField(blank=True)  class SubComment(object): @@ -71,6 +97,15 @@ class ActionItemSerializer(serializers.ModelSerializer):      class Meta:          model = ActionItem +class ActionItemSerializerOptionalFields(serializers.ModelSerializer): +    """ +    Intended to test that fields with `required=False` are excluded from validation. +    """ +    title = serializers.CharField(required=False) + +    class Meta: +        model = ActionItem +        fields = ('title',)  class ActionItemSerializerCustomRestore(serializers.ModelSerializer): @@ -132,7 +167,7 @@ class AlbumsSerializer(serializers.ModelSerializer):      class Meta:          model = Album -        fields = ['title']  # lists are also valid options +        fields = ['title', 'ref']  # lists are also valid options  class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer): @@ -288,7 +323,13 @@ class BasicTests(TestCase):          serializer.save()          self.assertIsNotNone(serializer.data.get('id',None), 'Model is saved. `id` should be set.') - +    def test_fields_marked_as_not_required_are_excluded_from_validation(self): +        """ +        Check that fields with `required=False` are included in list of exclusions. +        """ +        serializer = ActionItemSerializerOptionalFields(self.actionitem) +        exclusions = serializer.get_validation_exclusions() +        self.assertTrue('title' in exclusions, '`title` field was marked `required=False` and should be excluded')  class DictStyleSerializer(serializers.Serializer): @@ -467,6 +508,32 @@ class ValidationTests(TestCase):          )          self.assertEqual(serializer.is_valid(), True) +    def test_writable_star_source_on_nested_serializer_with_parent_object(self): +        class TitleSerializer(serializers.Serializer): +            title = serializers.WritableField(source='title') + +        class AlbumSerializer(serializers.ModelSerializer): +            nested = TitleSerializer(source='*') + +            class Meta: +                model = Album +                fields = ('nested',) + +        class PhotoSerializer(serializers.ModelSerializer): +            album = AlbumSerializer(source='album') + +            class Meta: +                model = Photo +                fields = ('album', ) + +        photo = Photo(album=Album()) + +        data = {'album': {'nested': {'title': 'test'}}} + +        serializer = PhotoSerializer(photo, data=data) +        self.assertEqual(serializer.is_valid(), True) +        self.assertEqual(serializer.data, data) +      def test_writable_star_source_with_inner_source_fields(self):          """          Tests that a serializer with source="*" correctly expands the @@ -576,12 +643,15 @@ class ModelValidationTests(TestCase):          """          Just check if serializers.ModelSerializer handles unique checks via .full_clean()          """ -        serializer = AlbumsSerializer(data={'title': 'a'}) +        serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'})          serializer.is_valid()          serializer.save()          second_serializer = AlbumsSerializer(data={'title': 'a'})          self.assertFalse(second_serializer.is_valid()) -        self.assertEqual(second_serializer.errors,  {'title': ['Album with this Title already exists.']}) +        self.assertEqual(second_serializer.errors,  {'title': ['Album with this Title already exists.'],}) +        third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}]) +        self.assertFalse(third_serializer.is_valid()) +        self.assertEqual(third_serializer.errors,  [{'ref': ['Album with this Ref already exists.']}, {}])      def test_foreign_key_is_null_with_partial(self):          """ @@ -865,6 +935,58 @@ class DefaultValueTests(TestCase):          self.assertEqual(instance.text, 'overridden') +class WritableFieldDefaultValueTests(TestCase): + +    def setUp(self): +        self.expected = {'default': 'value'} +        self.create_field = fields.WritableField + +    def test_get_default_value_with_noncallable(self): +        field = self.create_field(default=self.expected) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_with_callable(self): +        field = self.create_field(default=lambda : self.expected) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_when_not_required(self): +        field = self.create_field(default=self.expected, required=False) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_returns_None(self): +        field = self.create_field() +        got = field.get_default_value() +        self.assertIsNone(got) + +    def test_get_default_value_returns_non_True_values(self): +        values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause +        for expected in values: +            field = self.create_field(default=expected) +            got = field.get_default_value() +            self.assertEqual(got, expected) + + +class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): + +    def setUp(self): +        self.expected = {'foo': 'bar'} +        self.create_field = relations.RelatedField + +    def test_get_default_value_returns_empty_list(self): +        field = self.create_field(many=True) +        got = field.get_default_value() +        self.assertListEqual(got, []) + +    def test_get_default_value_returns_expected(self): +        expected = [1, 2, 3] +        field = self.create_field(many=True, default=expected) +        got = field.get_default_value() +        self.assertListEqual(got, expected) + +  class CallableDefaultValueTests(TestCase):      def setUp(self):          class CallableDefaultValueSerializer(serializers.ModelSerializer): @@ -1492,19 +1614,10 @@ class ManyFieldHelpTextTest(TestCase):          self.assertEqual('Some help text.', rel_field.help_text) +@unittest.skipUnless(PIL is not None, 'PIL is not installed')  class AttributeMappingOnAutogeneratedFieldsTests(TestCase):      def setUp(self): -        class AMOAFModel(RESTFrameworkModel): -            char_field = models.CharField(max_length=1024, blank=True) -            comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True) -            decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True) -            email_field = models.EmailField(max_length=1024, blank=True) -            file_field = models.FileField(max_length=1024, blank=True) -            image_field = models.ImageField(max_length=1024, blank=True) -            slug_field = models.SlugField(max_length=1024, blank=True) -            url_field = models.URLField(max_length=1024, blank=True) -            nullable_char_field = models.CharField(max_length=1024, blank=True, null=True)          class AMOAFSerializer(serializers.ModelSerializer):              class Meta: @@ -1581,17 +1694,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):          self.field_test('nullable_char_field') +@unittest.skipUnless(PIL is not None, 'PIL is not installed')  class DefaultValuesOnAutogeneratedFieldsTests(TestCase):      def setUp(self): -        class DVOAFModel(RESTFrameworkModel): -            positive_integer_field = models.PositiveIntegerField(blank=True) -            positive_small_integer_field = models.PositiveSmallIntegerField(blank=True) -            email_field = models.EmailField(blank=True) -            file_field = models.FileField(blank=True) -            image_field = models.ImageField(blank=True) -            slug_field = models.SlugField(blank=True) -            url_field = models.URLField(blank=True)          class DVOAFSerializer(serializers.ModelSerializer):              class Meta: @@ -1830,14 +1936,14 @@ class SerializerDefaultTrueBoolean(TestCase):          self.assertEqual(serializer.data['cat'], False)          self.assertEqual(serializer.data['dog'], False) -         +  class BoolenFieldTypeTest(TestCase):      '''      Ensure the various Boolean based model fields are rendered as the proper      field type -     +      ''' -     +      def setUp(self):          '''          Setup an ActionItemSerializer for BooleanTesting @@ -1853,11 +1959,11 @@ class BoolenFieldTypeTest(TestCase):          '''          bfield = self.serializer.get_fields()['done']          self.assertEqual(type(bfield), fields.BooleanField) -     +      def test_nullbooleanfield_type(self):          ''' -        Test that BooleanField is infered from models.NullBooleanField  -         +        Test that BooleanField is infered from models.NullBooleanField +          https://groups.google.com/forum/#!topic/django-rest-framework/D9mXEftpuQ8          '''          bfield = self.serializer.get_fields()['started'] diff --git a/rest_framework/tests/test_serializer_import.py b/rest_framework/tests/test_serializer_import.py new file mode 100644 index 00000000..9f30a7ff --- /dev/null +++ b/rest_framework/tests/test_serializer_import.py @@ -0,0 +1,19 @@ +from django.test import TestCase + +from rest_framework import serializers +from rest_framework.tests.accounts.serializers import AccountSerializer + + +class ImportingModelSerializerTests(TestCase): +    """ +    In some situations like, GH #1225, it is possible, especially in +    testing, to import a serializer who's related models have not yet +    been resolved by Django. `AccountSerializer` is an example of such +    a serializer (imported at the top of this file). +    """ +    def test_import_model_serializer(self): +        """ +        The serializer at the top of this file should have been +        imported successfully, and we should be able to instantiate it. +        """ +        self.assertIsInstance(AccountSerializer(), serializers.ModelSerializer) diff --git a/rest_framework/tests/test_serializer_nested.py b/rest_framework/tests/test_serializer_nested.py index 7114a060..6d69ffbd 100644 --- a/rest_framework/tests/test_serializer_nested.py +++ b/rest_framework/tests/test_serializer_nested.py @@ -345,4 +345,3 @@ class NestedModelSerializerUpdateTests(TestCase):          result = deserialize.object          result.save()          self.assertEqual(result.id, john.id) - diff --git a/rest_framework/tests/test_serializers.py b/rest_framework/tests/test_serializers.py new file mode 100644 index 00000000..082a400c --- /dev/null +++ b/rest_framework/tests/test_serializers.py @@ -0,0 +1,28 @@ +from django.db import models +from django.test import TestCase + +from rest_framework.serializers import _resolve_model +from rest_framework.tests.models import BasicModel + + +class ResolveModelTests(TestCase): +    """ +    `_resolve_model` should return a Django model class given the +    provided argument is a Django model class itself, or a properly +    formatted string representation of one. +    """ +    def test_resolve_django_model(self): +        resolved_model = _resolve_model(BasicModel) +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_string_representation(self): +        resolved_model = _resolve_model('tests.BasicModel') +        self.assertEqual(resolved_model, BasicModel) + +    def test_resolve_non_django_model(self): +        with self.assertRaises(ValueError): +            _resolve_model(TestCase) + +    def test_resolve_improper_string_representation(self): +        with self.assertRaises(ValueError): +            _resolve_model('BasicModel') diff --git a/rest_framework/tests/test_templatetags.py b/rest_framework/tests/test_templatetags.py new file mode 100644 index 00000000..d4da0c23 --- /dev/null +++ b/rest_framework/tests/test_templatetags.py @@ -0,0 +1,51 @@ +# encoding: utf-8 +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.test import APIRequestFactory +from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links + +factory = APIRequestFactory() + + +class TemplateTagTests(TestCase): + +    def test_add_query_param_with_non_latin_charactor(self): +        # Ensure we don't double-escape non-latin characters +        # that are present in the querystring. +        # See #1314. +        request = factory.get("/", {'q': '查询'}) +        json_url = add_query_param(request, "format", "json") +        self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url) +        self.assertIn("format=json", json_url) + + +class Issue1386Tests(TestCase): +    """ +    Covers #1386 +    """ + +    def test_issue_1386(self): +        """ +        Test function urlize_quoted_links with different args +        """ +        correct_urls = [ +            "asdf.com", +            "asdf.net", +            "www.as_df.org", +            "as.d8f.ghj8.gov", +        ] +        for i in correct_urls: +            res = urlize_quoted_links(i) +            self.assertNotEqual(res, i) +            self.assertIn(i, res) + +        incorrect_urls = [ +            "mailto://asdf@fdf.com", +            "asdf.netnet", +        ] +        for i in incorrect_urls: +            res = urlize_quoted_links(i) +            self.assertEqual(i, res) + +        # example from issue #1386, this shouldn't raise an exception +        _ = urlize_quoted_links("asdf:[/p]zxcv.com") diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index c08dd493..83ae8148 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -2,6 +2,8 @@  from __future__ import unicode_literals  from django.conf.urls import patterns, url +from io import BytesIO +  from django.contrib.auth.models import User  from django.test import TestCase  from rest_framework.decorators import api_view @@ -143,3 +145,20 @@ class TestAPIRequestFactory(TestCase):          force_authenticate(request, user=user)          response = view(request)          self.assertEqual(response.data['user'], 'example') + +    def test_upload_file(self): +        # This is a 1x1 black png +        simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82') +        simple_png.name = 'test.png' +        factory = APIRequestFactory() +        factory.post('/', data={'image': simple_png}) + +    def test_request_factory_url_arguments(self): +        """ +        This is a non regression test against #1461 +        """ +        factory = APIRequestFactory() +        request = factory.get('/view/?demo=test') +        self.assertEqual(dict(request.GET), {'demo': ['test']}) +        request = factory.get('/view/', {'demo': 'test'}) +        self.assertEqual(dict(request.GET), {'demo': ['test']}) diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py index 124c874d..e13e4078 100644 --- a/rest_framework/tests/test_validation.py +++ b/rest_framework/tests/test_validation.py @@ -1,4 +1,5 @@  from __future__ import unicode_literals +from django.core.validators import MaxValueValidator  from django.db import models  from django.test import TestCase  from rest_framework import generics, serializers, status @@ -102,3 +103,46 @@ class TestAvoidValidation(TestCase):          self.assertFalse(serializer.is_valid())          self.assertDictEqual(serializer.errors,                               {'non_field_errors': ['Invalid data']}) + + +# regression tests for issue: 1493 + +class ValidationMaxValueValidatorModel(models.Model): +    number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)]) + + +class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = ValidationMaxValueValidatorModel + + +class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView): +    model = ValidationMaxValueValidatorModel +    serializer_class = ValidationMaxValueValidatorModelSerializer + + +class TestMaxValueValidatorValidation(TestCase): + +    def test_max_value_validation_serializer_success(self): +        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99}) +        self.assertTrue(serializer.is_valid()) + +    def test_max_value_validation_serializer_fails(self): +        serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101}) +        self.assertFalse(serializer.is_valid()) +        self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors) + +    def test_max_value_validation_success(self): +        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) +        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json') +        view = UpdateMaxValueValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) + +    def test_max_value_validation_fail(self): +        obj = ValidationMaxValueValidatorModel.objects.create(number_value=100) +        request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json') +        view = UpdateMaxValueValidationModel().as_view() +        response = view(request, pk=obj.pk).render() +        self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}') +        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/tests/test_write_only_fields.py b/rest_framework/tests/test_write_only_fields.py new file mode 100644 index 00000000..aabb18d6 --- /dev/null +++ b/rest_framework/tests/test_write_only_fields.py @@ -0,0 +1,42 @@ +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class ExampleModel(models.Model): +    email = models.EmailField(max_length=100) +    password = models.CharField(max_length=100) + + +class WriteOnlyFieldTests(TestCase): +    def test_write_only_fields(self): +        class ExampleSerializer(serializers.Serializer): +            email = serializers.EmailField() +            password = serializers.CharField(write_only=True) + +        data = { +            'email': 'foo@example.com', +            'password': '123' +        } +        serializer = ExampleSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.object, data) +        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) + +    def test_write_only_fields_meta(self): +        class ExampleSerializer(serializers.ModelSerializer): +            class Meta: +                model = ExampleModel +                fields = ('email', 'password') +                write_only_fields = ('password',) + +        data = { +            'email': 'foo@example.com', +            'password': '123' +        } +        serializer = ExampleSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertTrue(isinstance(serializer.object, ExampleModel)) +        self.assertEquals(serializer.object.email, data['email']) +        self.assertEquals(serializer.object.password, data['password']) +        self.assertEquals(serializer.data, {'email': 'foo@example.com'}) diff --git a/rest_framework/tests/users/__init__.py b/rest_framework/tests/users/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/users/__init__.py diff --git a/rest_framework/tests/users/models.py b/rest_framework/tests/users/models.py new file mode 100644 index 00000000..128bac90 --- /dev/null +++ b/rest_framework/tests/users/models.py @@ -0,0 +1,6 @@ +from django.db import models + + +class User(models.Model): +    account = models.ForeignKey('accounts.Account', blank=True, null=True, related_name='users') +    active_record = models.ForeignKey('records.Record', blank=True, null=True) diff --git a/rest_framework/tests/users/serializers.py b/rest_framework/tests/users/serializers.py new file mode 100644 index 00000000..da496554 --- /dev/null +++ b/rest_framework/tests/users/serializers.py @@ -0,0 +1,8 @@ +from rest_framework import serializers + +from rest_framework.tests.users.models import User + + +class UserSerializer(serializers.ModelSerializer): +    class Meta: +        model = User diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..a8f2eb0b --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,25 @@ +from contextlib import contextmanager +from rest_framework.compat import six +from rest_framework.settings import api_settings + + +@contextmanager +def temporary_setting(setting, value, module=None): +    """ +    Temporarily change value of setting for test. + +    Optionally reload given module, useful when module uses value of setting on +    import. +    """ +    original_value = getattr(api_settings, setting) +    setattr(api_settings, setting, value) + +    if module is not None: +        six.moves.reload_module(module) + +    yield + +    setattr(api_settings, setting, original_value) + +    if module is not None: +        six.moves.reload_module(module) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py new file mode 100644 index 00000000..3917b74a --- /dev/null +++ b/rest_framework/tests/views.py @@ -0,0 +1,8 @@ +from rest_framework import generics +from rest_framework.tests.models import NullableForeignKeySource +from rest_framework.tests.serializers import NullableFKSourceSerializer + + +class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView): +    model = NullableForeignKeySource +    model_serializer_class = NullableFKSourceSerializer diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index c40f3065..fc24c92e 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -155,6 +155,8 @@ class SimpleRateThrottle(BaseThrottle):              remaining_duration = self.duration          available_requests = self.num_requests - len(self.history) + 1 +        if available_requests <= 0: +            return None          return remaining_duration / float(available_requests) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 229b0b28..c125ac8a 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -3,6 +3,7 @@ Helper classes for parsers.  """  from __future__ import unicode_literals  from django.utils import timezone +from django.db.models.query import QuerySet  from django.utils.datastructures import SortedDict  from django.utils.functional import Promise  from rest_framework.compat import force_text @@ -43,6 +44,8 @@ class JSONEncoder(json.JSONEncoder):              return str(o.total_seconds())          elif isinstance(o, decimal.Decimal):              return str(o) +        elif isinstance(o, QuerySet): +            return list(o)          elif hasattr(o, 'tolist'):              return o.tolist()          elif hasattr(o, '__getitem__'): diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index c09c2933..92f99efd 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -74,7 +74,7 @@ class _MediaType(object):              return 0          elif self.sub_type == '*':              return 1 -        elif not self.params or self.params.keys() == ['q']: +        elif not self.params or list(self.params.keys()) == ['q']:              return 2          return 3 diff --git a/rest_framework/views.py b/rest_framework/views.py index e863af6d..a2668f2c 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -112,12 +112,13 @@ class APIView(View):      @property      def default_response_headers(self): -        # TODO: deprecate? -        # TODO: Only vary by accept if multiple renderers -        return { +        headers = {              'Allow': ', '.join(self.allowed_methods), -            'Vary': 'Accept'          } +        if len(self.renderer_classes) > 1: +            headers['Vary'] = 'Accept' +        return headers +      def http_method_not_allowed(self, request, *args, **kwargs):          """ @@ -130,7 +131,7 @@ class APIView(View):          """          If request is not permitted, determine what kind of exception to raise.          """ -        if not self.request.successful_authenticator: +        if not request.successful_authenticator:              raise exceptions.NotAuthenticated()          raise exceptions.PermissionDenied() @@ -294,7 +295,7 @@ class APIView(View):      # Dispatch methods -    def initialize_request(self, request, *args, **kargs): +    def initialize_request(self, request, *args, **kwargs):          """          Returns the initial request object.          """ | 
