diff options
Diffstat (limited to 'rest_framework')
44 files changed, 1351 insertions, 301 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 151ba832..f9882c57 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '2.1.14' +__version__ = '2.1.17'  VERSION = __version__  # synonym diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index c50bf944..76ee4bd6 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -23,34 +23,47 @@ class BaseAuthentication(object):          """          raise NotImplementedError(".authenticate() must be overridden.") +    def authenticate_header(self, request): +        """ +        Return a string to be used as the value of the `WWW-Authenticate` +        header in a `401 Unauthenticated` response, or `None` if the +        authentication scheme should return `403 Permission Denied` responses. +        """ +        pass +  class BasicAuthentication(BaseAuthentication):      """      HTTP Basic authentication against username/password.      """ +    www_authenticate_realm = 'api'      def authenticate(self, request):          """          Returns a `User` if a correct username and password have been supplied          using HTTP Basic authentication.  Otherwise returns `None`.          """ -        if 'HTTP_AUTHORIZATION' in request.META: -            auth = request.META['HTTP_AUTHORIZATION'].split() -            if len(auth) == 2 and auth[0].lower() == "basic": -                try: -                    encoding = api_settings.HTTP_HEADER_ENCODING -                    b = base64.b64decode(auth[1].encode(encoding)) -                    auth_parts = b.decode(encoding).partition(':') -                except TypeError: -                    return None - -                try: -                    userid = smart_text(auth_parts[0]) -                    password = smart_text(auth_parts[2]) -                except DjangoUnicodeDecodeError: -                    return None - -                return self.authenticate_credentials(userid, password) +        auth = request.META.get('HTTP_AUTHORIZATION', '').split() + +        if not auth or auth[0].lower() != "basic": +            return None + +        if len(auth) != 2: +            raise exceptions.AuthenticationFailed('Invalid basic header') + +        encoding = api_settings.HTTP_HEADER_ENCODING +        try: +            auth_parts = base64.b64decode(auth[1].encode(encoding)).partition(':') +        except TypeError: +            raise exceptions.AuthenticationFailed('Invalid basic header') + +        try: +            userid = smart_text(auth_parts[0]) +            password = smart_text(auth_parts[2]) +        except DjangoUnicodeDecodeError: +            raise exceptions.AuthenticationFailed('Invalid basic header') + +        return self.authenticate_credentials(userid, password)      def authenticate_credentials(self, userid, password):          """ @@ -59,6 +72,10 @@ class BasicAuthentication(BaseAuthentication):          user = authenticate(username=userid, password=password)          if user is not None and user.is_active:              return (user, None) +        raise exceptions.AuthenticationFailed('Invalid username/password') + +    def authenticate_header(self, request): +        return 'Basic realm="%s"' % self.www_authenticate_realm  class SessionAuthentication(BaseAuthentication): @@ -78,7 +95,7 @@ class SessionAuthentication(BaseAuthentication):          # Unauthenticated, CSRF validation not required          if not user or not user.is_active: -            return +            return None          # Enforce CSRF validation for session based authentication.          class CSRFCheck(CsrfViewMiddleware): @@ -89,7 +106,7 @@ class SessionAuthentication(BaseAuthentication):          reason = CSRFCheck().process_view(http_request, None, (), {})          if reason:              # CSRF failed, bail with explicit error message -            raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) +            raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason)          # CSRF passed with authenticated user          return (user, None) @@ -116,14 +133,26 @@ class TokenAuthentication(BaseAuthentication):      def authenticate(self, request):          auth = request.META.get('HTTP_AUTHORIZATION', '').split() -        if len(auth) == 2 and auth[0].lower() == "token": -            key = auth[1] -            try: -                token = self.model.objects.get(key=key) -            except self.model.DoesNotExist: -                return None +        if not auth or auth[0].lower() != "token": +            return None + +        if len(auth) != 2: +            raise exceptions.AuthenticationFailed('Invalid token header') + +        return self.authenticate_credentials(auth[1]) + +    def authenticate_credentials(self, key): +        try: +            token = self.model.objects.get(key=key) +        except self.model.DoesNotExist: +            raise exceptions.AuthenticationFailed('Invalid token') + +        if token.user.is_active: +            return (token.user, token) +        raise exceptions.AuthenticationFailed('User inactive or deleted') + +    def authenticate_header(self, request): +        return 'Token' -            if token.user.is_active: -                return (token.user, token)  # TODO: OAuthAuthentication diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index d318c723..7c03cb76 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -12,10 +12,11 @@ class ObtainAuthToken(APIView):      permission_classes = ()      parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)      renderer_classes = (renderers.JSONRenderer,) +    serializer_class = AuthTokenSerializer      model = Token      def post(self, request): -        serializer = AuthTokenSerializer(data=request.DATA) +        serializer = self.serializer_class(data=request.DATA)          if serializer.is_valid():              token, created = Token.objects.get_or_create(user=serializer.object['user'])              return Response({'token': token.key}) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 5924cd6d..ef11b85b 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -126,6 +126,12 @@ else:              update_wrapper(view, cls.dispatch, assigned=())              return view +# Taken from @markotibold's attempt at supporting PATCH. +# https://github.com/markotibold/django-rest-framework/tree/patch +http_method_names = set(View.http_method_names) +http_method_names.add('patch') +View.http_method_names = list(http_method_names)  # PATCH method is not implemented by Django +  # PUT, DELETE do not require CSRF until 1.4.  They should.  Make it better.  if django.VERSION >= (1, 4):      from django.middleware.csrf import CsrfViewMiddleware diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 1b710a03..7a4103e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,4 +1,5 @@  from rest_framework.views import APIView +import types  def api_view(http_method_names): @@ -23,6 +24,14 @@ def api_view(http_method_names):          #         pass          #     WrappedAPIView.__doc__ = func.doc    <--- Not possible to do this +        # api_view applied without (method_names) +        assert not(isinstance(http_method_names, types.FunctionType)), \ +            '@api_view missing list of allowed HTTP methods' + +        # api_view applied with eg. string instead of list of strings +        assert isinstance(http_method_names, (list, tuple)), \ +            '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ +          allowed_methods = set(http_method_names) | set(('options',))          WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 89479deb..d635351c 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -23,6 +23,22 @@ class ParseError(APIException):          self.detail = detail or self.default_detail +class AuthenticationFailed(APIException): +    status_code = status.HTTP_401_UNAUTHORIZED +    default_detail = 'Incorrect authentication credentials.' + +    def __init__(self, detail=None): +        self.detail = detail or self.default_detail + + +class NotAuthenticated(APIException): +    status_code = status.HTTP_401_UNAUTHORIZED +    default_detail = 'Authentication credentials were not provided.' + +    def __init__(self, detail=None): +        self.detail = detail or self.default_detail + +  class PermissionDenied(APIException):      status_code = status.HTTP_403_FORBIDDEN      default_detail = 'You do not have permission to perform this action.' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index adea5bf5..a66e1d7c 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1,4 +1,3 @@ -  from __future__ import unicode_literals  import copy @@ -185,11 +184,13 @@ class WritableField(Field):          try:              if self._use_files: +                files = files or {}                  native = files[field_name]              else:                  native = data[field_name]          except KeyError: -            if self.default is not None: +            if self.default is not None and not self.root.partial: +                # Note: partial updates shouldn't set defaults                  native = self.default              else:                  if self.required: diff --git a/rest_framework/generics.py b/rest_framework/generics.py index dd8dfcf8..19f2b704 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -47,14 +47,16 @@ class GenericAPIView(views.APIView):          return serializer_class -    def get_serializer(self, instance=None, data=None, files=None): +    def get_serializer(self, instance=None, data=None, +                       files=None, partial=False):          """          Return the serializer instance that should be used for validating and          deserializing input, and for serializing output.          """          serializer_class = self.get_serializer_class()          context = self.get_serializer_context() -        return serializer_class(instance, data=data, files=files, context=context) +        return serializer_class(instance, data=data, files=files, +                                partial=partial, context=context)  class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): @@ -171,6 +173,10 @@ class UpdateAPIView(mixins.UpdateModelMixin,      def put(self, request, *args, **kwargs):          return self.update(request, *args, **kwargs) +    def patch(self, request, *args, **kwargs): +        kwargs['partial'] = True +        return self.update(request, *args, **kwargs) +  class ListCreateAPIView(mixins.ListModelMixin,                          mixins.CreateModelMixin, @@ -185,6 +191,23 @@ class ListCreateAPIView(mixins.ListModelMixin,          return self.create(request, *args, **kwargs) +class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, +                            mixins.UpdateModelMixin, +                            SingleObjectAPIView): +    """ +    Concrete view for retrieving, updating a model instance. +    """ +    def get(self, request, *args, **kwargs): +        return self.retrieve(request, *args, **kwargs) + +    def put(self, request, *args, **kwargs): +        return self.update(request, *args, **kwargs) + +    def patch(self, request, *args, **kwargs): +        kwargs['partial'] = True +        return self.update(request, *args, **kwargs) + +  class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,                               mixins.DestroyModelMixin,                               SingleObjectAPIView): @@ -211,5 +234,9 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,      def put(self, request, *args, **kwargs):          return self.update(request, *args, **kwargs) +    def patch(self, request, *args, **kwargs): +        kwargs['partial'] = True +        return self.update(request, *args, **kwargs) +      def delete(self, request, *args, **kwargs):          return self.destroy(request, *args, **kwargs) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 503376ce..acaf8a71 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -18,11 +18,14 @@ class CreateModelMixin(object):      """      def create(self, request, *args, **kwargs):          serializer = self.get_serializer(data=request.DATA, files=request.FILES) +          if serializer.is_valid():              self.pre_save(serializer.object)              self.object = serializer.save()              headers = self.get_success_headers(serializer.data) -            return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) +            return Response(serializer.data, status=status.HTTP_201_CREATED, +                            headers=headers) +          return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)      def get_success_headers(self, data): @@ -84,20 +87,21 @@ class UpdateModelMixin(object):      Should be mixed in with `SingleObjectBaseView`.      """      def update(self, request, *args, **kwargs): +        partial = kwargs.pop('partial', False)          try:              self.object = self.get_object() -            created = False +            success_status_code = status.HTTP_200_OK          except Http404:              self.object = None -            created = True +            success_status_code = status.HTTP_201_CREATED -        serializer = self.get_serializer(self.object, data=request.DATA, files=request.FILES) +        serializer = self.get_serializer(self.object, data=request.DATA, +                                         files=request.FILES, partial=partial)          if serializer.is_valid():              self.pre_save(serializer.object)              self.object = serializer.save() -            status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK -            return Response(serializer.data, status=status_code) +            return Response(serializer.data, status=success_status_code)          return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -117,7 +121,8 @@ class UpdateModelMixin(object):          # Ensure we clean the attributes so that we don't eg return integer          # pk using a string representation, as provided by the url conf kwarg. -        obj.full_clean() +        if hasattr(obj, 'full_clean'): +            obj.full_clean()  class DestroyModelMixin(object): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index d241ade7..92d41e0e 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -34,6 +34,17 @@ class PreviousPageField(serializers.Field):          return replace_query_param(url, self.page_field, page) +class DefaultObjectSerializer(serializers.Field): +    """ +    If no object serializer is specified, then this serializer will be applied +    as the default. +    """ + +    def __init__(self, source=None, context=None): +        # Note: Swallow context kwarg - only required for eg. ModelSerializer. +        super(DefaultObjectSerializer, self).__init__(source=source) + +  class PaginationSerializerOptions(serializers.SerializerOptions):      """      An object that stores the options that may be provided to a @@ -44,7 +55,7 @@ class PaginationSerializerOptions(serializers.SerializerOptions):      def __init__(self, meta):          super(PaginationSerializerOptions, self).__init__(meta)          self.object_serializer_class = getattr(meta, 'object_serializer_class', -                                               serializers.Field) +                                               DefaultObjectSerializer)  class BasePaginationSerializer(serializers.Serializer): @@ -62,14 +73,13 @@ class BasePaginationSerializer(serializers.Serializer):          super(BasePaginationSerializer, self).__init__(*args, **kwargs)          results_field = self.results_field          object_serializer = self.opts.object_serializer_class -        self.fields[results_field] = object_serializer(source='object_list') -    def to_native(self, obj): -        """ -        Prevent default behaviour of iterating over elements, and serializing -        each in turn. -        """ -        return self.convert_object(obj) +        if 'context' in kwargs: +            context_kwarg = {'context': kwargs['context']} +        else: +            context_kwarg = {} + +        self.fields[results_field] = object_serializer(source='object_list', **context_kwarg)  class PaginationSerializer(BasePaginationSerializer): diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 7c01006a..4a2b34a5 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -8,12 +8,12 @@ on the request, such as form content or json encoded data.  from django.http import QueryDict  from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser  from django.http.multipartparser import MultiPartParserError -from django.utils import simplejson as json  from rest_framework.compat import yaml, ETParseError  from rest_framework.exceptions import ParseError  from rest_framework.compat import six  from xml.etree import ElementTree as ET  from xml.parsers.expat import ExpatError +import json  import datetime  import decimal diff --git a/rest_framework/relations.py b/rest_framework/relations.py index b7a6e0c1..c4f854ef 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -6,6 +6,7 @@ from django.core.urlresolvers import resolve, get_script_prefix  from django import forms  from django.forms import widgets  from django.forms.models import ModelChoiceIterator +from django.utils.translation import ugettext_lazy as _  from rest_framework.fields import Field, WritableField  from rest_framework.reverse import reverse  from rest_framework.compat import urlparse @@ -103,7 +104,13 @@ class RelatedField(WritableField):      ### Regular serializer stuff...      def field_to_native(self, obj, field_name): -        value = getattr(obj, self.source or field_name) +        try: +            value = getattr(obj, self.source or field_name) +        except ObjectDoesNotExist: +            return None + +        if value is None: +            return None          return self.to_native(value)      def field_from_native(self, data, files, field_name, into): @@ -144,7 +151,7 @@ class ManyRelatedMixin(object):              value = data.getlist(self.source or field_name)          except:              # Non-form data -            value = data.get(self.source or field_name) +            value = data.get(self.source or field_name, [])          else:              if value == ['']:                  value = [] @@ -171,6 +178,11 @@ class PrimaryKeyRelatedField(RelatedField):      default_read_only = False      form_field_class = forms.ChoiceField +    default_error_messages = { +        'does_not_exist': _("Invalid pk '%s' - object does not exist."), +        'incorrect_type': _('Incorrect type.  Expected pk value, received %s.'), +    } +      # TODO: Remove these field hacks...      def prepare_value(self, obj):          return self.to_native(obj.pk) @@ -196,7 +208,11 @@ class PrimaryKeyRelatedField(RelatedField):          try:              return self.queryset.get(pk=data)          except ObjectDoesNotExist: -            msg = "Invalid pk '%s' - object does not exist." % smart_text(data) +            msg = self.error_messages['does_not_exist'] % smart_text(data) +            raise ValidationError(msg) +        except (TypeError, ValueError): +            received = type(data).__name__ +            msg = self.error_messages['incorrect_type'] % received              raise ValidationError(msg)      def field_to_native(self, obj, field_name): @@ -205,7 +221,10 @@ class PrimaryKeyRelatedField(RelatedField):              pk = obj.serializable_value(self.source or field_name)          except AttributeError:              # RelatedObject (reverse relationship) -            obj = getattr(obj, self.source or field_name) +            try: +                obj = getattr(obj, self.source or field_name) +            except ObjectDoesNotExist: +                return None              return self.to_native(obj.pk)          # Forward relationship          return self.to_native(pk) @@ -218,6 +237,11 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):      default_read_only = False      form_field_class = forms.MultipleChoiceField +    default_error_messages = { +        'does_not_exist': _("Invalid pk '%s' - object does not exist."), +        'incorrect_type': _('Incorrect type.  Expected pk value, received %s.'), +    } +      def prepare_value(self, obj):          return self.to_native(obj.pk) @@ -252,7 +276,11 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):          try:              return self.queryset.get(pk=data)          except ObjectDoesNotExist: -            msg = "Invalid pk '%s' - object does not exist." % smart_text(data) +            msg = self.error_messages['does_not_exist'] % smart_text(data) +            raise ValidationError(msg) +        except (TypeError, ValueError): +            received = type(data).__name__ +            msg = self.error_messages['incorrect_type'] % received              raise ValidationError(msg)  ### Slug relationships @@ -262,6 +290,11 @@ class SlugRelatedField(RelatedField):      default_read_only = False      form_field_class = forms.ChoiceField +    default_error_messages = { +        'does_not_exist': _("Object with %s=%s does not exist."), +        'invalid': _('Invalid value.'), +    } +      def __init__(self, *args, **kwargs):          self.slug_field = kwargs.pop('slug_field', None)          assert self.slug_field, 'slug_field is required' @@ -277,8 +310,11 @@ class SlugRelatedField(RelatedField):          try:              return self.queryset.get(**{self.slug_field: data})          except ObjectDoesNotExist: -            raise ValidationError('Object with %s=%s does not exist.' % +            raise ValidationError(self.error_messages['does_not_exist'] %                                    (self.slug_field, unicode(data))) +        except (TypeError, ValueError): +            msg = self.error_messages['invalid'] +            raise ValidationError(msg)  class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField): @@ -297,6 +333,14 @@ class HyperlinkedRelatedField(RelatedField):      default_read_only = False      form_field_class = forms.ChoiceField +    default_error_messages = { +        'no_match': _('Invalid hyperlink - No URL match'), +        'incorrect_match': _('Invalid hyperlink - Incorrect URL match'), +        'configuration_error': _('Invalid hyperlink due to configuration error'), +        'does_not_exist': _("Invalid hyperlink - object does not exist."), +        'incorrect_type': _('Incorrect type.  Expected url string, received %s.'), +    } +      def __init__(self, *args, **kwargs):          try:              self.view_name = kwargs.pop('view_name') @@ -333,21 +377,21 @@ class HyperlinkedRelatedField(RelatedField):          slug = getattr(obj, self.slug_field, None)          if not slug: -            raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) +            raise Exception('Could not resolve URL for field using view name "%s"' % view_name)          kwargs = {self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass          kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass -        raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) +        raise Exception('Could not resolve URL for field using view name "%s"' % view_name)      def from_native(self, value):          # Convert URL -> model instance pk @@ -355,7 +399,13 @@ class HyperlinkedRelatedField(RelatedField):          if self.queryset is None:              raise Exception('Writable related fields must include a `queryset` argument') -        if value.startswith('http:') or value.startswith('https:'): +        try: +            http_prefix = value.startswith('http:') or value.startswith('https:') +        except AttributeError: +            msg = self.error_messages['incorrect_type'] +            raise ValidationError(msg % type(value).__name__) + +        if http_prefix:              # If needed convert absolute URLs to relative path              value = urlparse.urlparse(value).path              prefix = get_script_prefix() @@ -365,10 +415,10 @@ class HyperlinkedRelatedField(RelatedField):          try:              match = resolve(value)          except: -            raise ValidationError('Invalid hyperlink - No URL match') +            raise ValidationError(self.error_messages['no_match']) -        if match.url_name != self.view_name: -            raise ValidationError('Invalid hyperlink - Incorrect URL match') +        if match.view_name != self.view_name: +            raise ValidationError(self.error_messages['incorrect_match'])          pk = match.kwargs.get(self.pk_url_kwarg, None)          slug = match.kwargs.get(self.slug_url_kwarg, None) @@ -380,14 +430,18 @@ class HyperlinkedRelatedField(RelatedField):          elif slug is not None:              slug_field = self.get_slug_field()              queryset = self.queryset.filter(**{slug_field: slug}) -        # If none of those are defined, it's an error. +        # If none of those are defined, it's probably a configuation error.          else: -            raise ValidationError('Invalid hyperlink') +            raise ValidationError(self.error_messages['configuration_error'])          try:              obj = queryset.get()          except ObjectDoesNotExist: -            raise ValidationError('Invalid hyperlink - object does not exist.') +            raise ValidationError(self.error_messages['does_not_exist']) +        except (TypeError, ValueError): +            msg = self.error_messages['incorrect_type'] +            raise ValidationError(msg % type(value).__name__) +          return obj @@ -410,6 +464,7 @@ class HyperlinkedIdentityField(Field):          # TODO: Make view_name mandatory, and have the          # HyperlinkedModelSerializer set it on-the-fly          self.view_name = kwargs.pop('view_name', None) +        # Optionally the format of the target hyperlink may be specified          self.format = kwargs.pop('format', None)          self.slug_field = kwargs.pop('slug_field', self.slug_field) @@ -421,9 +476,22 @@ class HyperlinkedIdentityField(Field):      def field_to_native(self, obj, field_name):          request = self.context.get('request', None) -        format = self.format or self.context.get('format', None) +        format = self.context.get('format', None)          view_name = self.view_name or self.parent.opts.view_name          kwargs = {self.pk_url_kwarg: obj.pk} + +        # By default use whatever format is given for the current context +        # unless the target is a different type to the source. +        # +        # Eg. Consider a HyperlinkedIdentityField pointing from a json +        # representation to an html property of that representation... +        # +        # '/snippets/1/' should link to '/snippets/1/highlight/' +        # ...but... +        # '/snippets/1/.json' should link to '/snippets/1/highlight/.html' +        if format and self.format and self.format != format: +            format = self.format +          try:              return reverse(view_name, kwargs=kwargs, request=request, format=format)          except: @@ -432,18 +500,18 @@ class HyperlinkedIdentityField(Field):          slug = getattr(obj, self.slug_field, None)          if not slug: -            raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) +            raise Exception('Could not resolve URL for field using view name "%s"' % view_name)          kwargs = {self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass          kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass -        raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) +        raise Exception('Could not resolve URL for field using view name "%s"' % view_name) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 54930167..b3ee0690 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -10,10 +10,10 @@ from __future__ import unicode_literals  import copy  import string +import json  from django import forms  from django.http.multipartparser import parse_header  from django.template import RequestContext, loader, Template -from django.utils import simplejson as json  from rest_framework.compat import yaml  from rest_framework.exceptions import ConfigurationError  from rest_framework.settings import api_settings diff --git a/rest_framework/request.py b/rest_framework/request.py index 048a1c41..23e1da87 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -86,6 +86,7 @@ class Request(object):          self._method = Empty          self._content_type = Empty          self._stream = Empty +        self._authenticator = None          if self.parser_context is None:              self.parser_context = {} @@ -166,7 +167,7 @@ class Request(object):          by the authentication classes provided to the request.          """          if not hasattr(self, '_user'): -            self._user, self._auth = self._authenticate() +            self._authenticator, self._user, self._auth = self._authenticate()          return self._user      @user.setter @@ -185,7 +186,7 @@ class Request(object):          request, such as an authentication token.          """          if not hasattr(self, '_auth'): -            self._user, self._auth = self._authenticate() +            self._authenticator, self._user, self._auth = self._authenticate()          return self._auth      @auth.setter @@ -196,6 +197,14 @@ class Request(object):          """          self._auth = value +    @property +    def successful_authenticator(self): +        """ +        Return the instance of the authentication instance class that was used +        to authenticate the request, or `None`. +        """ +        return self._authenticator +      def _load_data_and_files(self):          """          Parses the request content into self.DATA and self.FILES. @@ -299,21 +308,23 @@ class Request(object):      def _authenticate(self):          """ -        Attempt to authenticate the request using each authentication instance in turn. -        Returns a two-tuple of (user, authtoken). +        Attempt to authenticate the request using each authentication instance +        in turn. +        Returns a three-tuple of (authenticator, user, authtoken).          """          for authenticator in self.authenticators:              user_auth_tuple = authenticator.authenticate(self)              if not user_auth_tuple is None: -                return user_auth_tuple +                user, auth = user_auth_tuple +                return (authenticator, user, auth)          return self._not_authenticated()      def _not_authenticated(self):          """ -        Return a two-tuple of (user, authtoken), representing an -        unauthenticated request. +        Return a three-tuple of (authenticator, user, authtoken), representing +        an unauthenticated request. -        By default this will be (AnonymousUser, None). +        By default this will be (None, AnonymousUser, None).          """          if api_settings.UNAUTHENTICATED_USER:              user = api_settings.UNAUTHENTICATED_USER() @@ -325,7 +336,7 @@ class Request(object):          else:              auth = None -        return (user, auth) +        return (None, user, auth)      def __getattr__(self, attr):          """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 663f166b..3d3bcb3c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -2,6 +2,7 @@ import copy  import datetime  import types  from decimal import Decimal +from django.core.paginator import Page  from django.db import models  from django.forms import widgets  from django.utils.datastructures import SortedDict @@ -209,6 +210,11 @@ class BaseSerializer(Field):          Converts a dictionary of data into a dictionary of deserialized fields.          """          reverted_data = {} + +        if data is not None and not isinstance(data, dict): +            self._errors['non_field_errors'] = [u'Invalid data'] +            return None +          for field_name, field in self.fields.items():              field.initialize(parent=self, field_name=field_name)              try: @@ -223,6 +229,8 @@ class BaseSerializer(Field):          Run `validate_<fieldname>()` and `validate()` methods on the serializer          """          for field_name, field in self.fields.items(): +            if field_name in self._errors: +                continue              try:                  validate_method = getattr(self, 'validate_%s' % field_name, None)                  if validate_method: @@ -267,7 +275,11 @@ class BaseSerializer(Field):          """          Serialize objects -> primitives.          """ -        if hasattr(obj, '__iter__'): +        # Note: At the moment we have an ugly hack to determine if we should +        # walk over iterables.  At some point, serializers will require an +        # explicit `many=True` in order to iterate over a set, and this hack +        # will disappear. +        if hasattr(obj, '__iter__') and not isinstance(obj, Page):              return [self.convert_object(item) for item in obj]          return self.convert_object(obj) @@ -277,7 +289,7 @@ class BaseSerializer(Field):          """          if hasattr(data, '__iter__') and not isinstance(data, dict):              # TODO: error data when deserializing lists -            return (self.from_native(item) for item in data) +            return [self.from_native(item, None) for item in data]          self._errors = {}          if data is not None or files is not None: @@ -294,15 +306,21 @@ class BaseSerializer(Field):          Override default so that we can apply ModelSerializer as a nested          field to relationships.          """ -        if self.source: -            for component in self.source.split('.'): -                obj = getattr(obj, component) +        if self.source == '*': +            return self.to_native(obj) + +        try: +            if self.source: +                for component in self.source.split('.'): +                    obj = getattr(obj, component) +                    if is_simple_callable(obj): +                        obj = obj() +            else: +                obj = getattr(obj, field_name)                  if is_simple_callable(obj):                      obj = obj() -        else: -            obj = getattr(obj, field_name) -            if is_simple_callable(obj): -                obj = value() +        except ObjectDoesNotExist: +            return None          # If the object has an "all" method, assume it's a relationship          if is_simple_callable(getattr(obj, 'all', None)): @@ -408,7 +426,7 @@ class ModelSerializer(Serializer):          """          Returns a default instance of the pk field.          """ -        return Field() +        return self.get_field(model_field)      def get_nested_field(self, model_field):          """ @@ -426,7 +444,7 @@ class ModelSerializer(Serializer):          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to)          kwargs = { -            'null': model_field.null, +            'null': model_field.null or model_field.blank,              'queryset': model_field.rel.to._default_manager          } @@ -445,11 +463,14 @@ class ModelSerializer(Serializer):          if model_field.null or model_field.blank:              kwargs['required'] = False +        if isinstance(model_field, models.AutoField) or not model_field.editable: +            kwargs['read_only'] = True +          if model_field.has_default():              kwargs['required'] = False              kwargs['default'] = model_field.get_default() -        if model_field.__class__ == models.TextField: +        if issubclass(model_field.__class__, models.TextField):              kwargs['widget'] = widgets.Textarea          # TODO: TypedChoiceField? @@ -458,6 +479,7 @@ class ModelSerializer(Serializer):              return ChoiceField(**kwargs)          field_mapping = { +            models.AutoField: IntegerField,              models.FloatField: FloatField,              models.IntegerField: IntegerField,              models.PositiveIntegerField: IntegerField, @@ -492,6 +514,22 @@ class ModelSerializer(Serializer):                  exclusions.remove(field_name)          return exclusions +    def full_clean(self, instance): +        """ +        Perform Django's full_clean, and populate the `errors` dictionary +        if any validation errors occur. + +        Note that we don't perform this inside the `.restore_object()` method, +        so that subclasses can override `.restore_object()`, and still get +        the full_clean validation checking. +        """ +        try: +            instance.full_clean(exclude=self.get_validation_exclusions()) +        except ValidationError, err: +            self._errors = err.message_dict +            return None +        return instance +      def restore_object(self, attrs, instance=None):          """          Restore the model instance. @@ -531,13 +569,21 @@ class ModelSerializer(Serializer):          return instance -    def save(self, save_m2m=True): +    def from_native(self, data, files): +        """ +        Override the default method to also include model field validation. +        """ +        instance = super(ModelSerializer, self).from_native(data, files) +        if instance: +            return self.full_clean(instance) + +    def save(self):          """          Save the deserialized object and return it.          """          self.object.save() -        if getattr(self, 'm2m_data', None) and save_m2m: +        if getattr(self, 'm2m_data', None):              for accessor_name, object_list in self.m2m_data.items():                  setattr(self.object, accessor_name, object_list)              self.m2m_data = {} diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 2358d188..13d03e62 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -119,9 +119,8 @@ def import_from_string(val, setting_name):          module_path, class_name = '.'.join(parts[:-1]), parts[-1]          module = importlib.import_module(module_path)          return getattr(module, class_name) -    except: -        raise -        msg = "Could not import '%s' for API setting '%s'" % (val, setting_name) +    except ImportError as e: +        msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)          raise ImportError(msg) diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 42e49cb9..092bf2e4 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -112,7 +112,7 @@              <div class="request-info">                  <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre> -            <div> +            </div>              <div class="response-info">                  <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}  {% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span> diff --git a/rest_framework/templates/rest_framework/login.html b/rest_framework/templates/rest_framework/login.html index 6e2bd8d4..e10ce20f 100644 --- a/rest_framework/templates/rest_framework/login.html +++ b/rest_framework/templates/rest_framework/login.html @@ -25,14 +25,14 @@                      <form action="{% url 'rest_framework:login' %}" class=" form-inline" method="post">                          {% csrf_token %}                          <div id="div_id_username" class="clearfix control-group"> -                            <div class="controls" style="height: 30px"> -                                <Label class="span4" style="margin-top: 3px">Username:</label> +                            <div class="controls"> +                                <Label class="span4">Username:</label>                                  <input style="height: 25px" type="text" name="username" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_username">                              </div>                          </div>                          <div id="div_id_password" class="clearfix control-group"> -                            <div class="controls" style="height: 30px"> -                                <Label class="span4" style="margin-top: 3px">Password:</label> +                            <div class="controls"> +                                <Label class="span4">Password:</label>                                  <input style="height: 25px" type="password" name="password" maxlength="100" autocapitalize="off" autocorrect="off" class="textinput textInput" id="id_password">                              </div>                          </div> diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 4205e57c..cbafbe0e 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -27,7 +27,7 @@ register = template.Library()  # conflicts with this rest_framework template tag module.  try:  # Django 1.5+ -    from django.contrib.staticfiles.templatetags import StaticFilesNode +    from django.contrib.staticfiles.templatetags.staticfiles import StaticFilesNode      @register.tag('static')      def do_static(parser, token): diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 8c0bfc47..ba2042cb 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,14 +1,13 @@  from django.contrib.auth.models import User  from django.http import HttpResponse  from django.test import Client, TestCase -from django.utils import simplejson as json -  from rest_framework import permissions  from rest_framework.authtoken.models import Token -from rest_framework.authentication import TokenAuthentication +from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication  from rest_framework.compat import patterns  from rest_framework.views import APIView +import json  import base64 @@ -21,10 +20,10 @@ class MockView(APIView):      def put(self, request):          return HttpResponse({'a': 1, 'b': 2, 'c': 3}) -MockView.authentication_classes += (TokenAuthentication,) -  urlpatterns = patterns('', -    (r'^$', MockView.as_view()), +    (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), +    (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), +    (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),      (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),  ) @@ -42,25 +41,26 @@ class BasicAuthTests(TestCase):      def test_post_form_passing_basic_auth(self):          """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" -        auth = 'Basic ' + base64.encodestring(('%s:%s' % (self.username, self.password)).encode('iso-8859-1')).strip().decode('iso-8859-1') -        response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).encode('iso-8859-1').strip().decode('iso-8859-1') +        response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      def test_post_json_passing_basic_auth(self):          """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" -        auth = 'Basic ' + base64.encodestring(('%s:%s' % (self.username, self.password)).encode('iso-8859-1')).strip().decode('iso-8859-1') -        response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) +        auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).encode('iso-8859-1').strip().decode('iso-8859-1') +        response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      def test_post_form_failing_basic_auth(self):          """Ensure POSTing form over basic auth without correct credentials fails""" -        response = self.csrf_client.post('/', {'example': 'example'}) -        self.assertEqual(response.status_code, 403) +        response = self.csrf_client.post('/basic/', {'example': 'example'}) +        self.assertEqual(response.status_code, 401)      def test_post_json_failing_basic_auth(self):          """Ensure POSTing json over basic auth without correct credentials fails""" -        response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') -        self.assertEqual(response.status_code, 403) +        response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') +        self.assertEqual(response.status_code, 401) +        self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')  class SessionAuthTests(TestCase): @@ -83,7 +83,7 @@ class SessionAuthTests(TestCase):          Ensure POSTing form over session authentication without CSRF token fails.          """          self.csrf_client.login(username=self.username, password=self.password) -        response = self.csrf_client.post('/', {'example': 'example'}) +        response = self.csrf_client.post('/session/', {'example': 'example'})          self.assertEqual(response.status_code, 403)      def test_post_form_session_auth_passing(self): @@ -91,7 +91,7 @@ class SessionAuthTests(TestCase):          Ensure POSTing form over session authentication with logged in user and CSRF token passes.          """          self.non_csrf_client.login(username=self.username, password=self.password) -        response = self.non_csrf_client.post('/', {'example': 'example'}) +        response = self.non_csrf_client.post('/session/', {'example': 'example'})          self.assertEqual(response.status_code, 200)      def test_put_form_session_auth_passing(self): @@ -99,14 +99,14 @@ class SessionAuthTests(TestCase):          Ensure PUTting form over session authentication with logged in user and CSRF token passes.          """          self.non_csrf_client.login(username=self.username, password=self.password) -        response = self.non_csrf_client.put('/', {'example': 'example'}) +        response = self.non_csrf_client.put('/session/', {'example': 'example'})          self.assertEqual(response.status_code, 200)      def test_post_form_session_auth_failing(self):          """          Ensure POSTing form over session authentication without logged in user fails.          """ -        response = self.csrf_client.post('/', {'example': 'example'}) +        response = self.csrf_client.post('/session/', {'example': 'example'})          self.assertEqual(response.status_code, 403) @@ -127,24 +127,24 @@ class TokenAuthTests(TestCase):      def test_post_form_passing_token_auth(self):          """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""          auth = "Token " + self.key -        response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      def test_post_json_passing_token_auth(self):          """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""          auth = "Token " + self.key -        response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 200)      def test_post_form_failing_token_auth(self):          """Ensure POSTing form over token auth without correct credentials fails""" -        response = self.csrf_client.post('/', {'example': 'example'}) -        self.assertEqual(response.status_code, 403) +        response = self.csrf_client.post('/token/', {'example': 'example'}) +        self.assertEqual(response.status_code, 401)      def test_post_json_failing_token_auth(self):          """Ensure POSTing json over token auth without correct credentials fails""" -        response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') -        self.assertEqual(response.status_code, 403) +        response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') +        self.assertEqual(response.status_code, 401)      def test_token_has_auto_assigned_key_if_none_provided(self):          """Ensure creating a token with no key will auto-assign a key""" diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index 8079c8cb..82f912e9 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,5 +1,4 @@  from django.test import TestCase -from django.test.client import RequestFactory  from rest_framework import status  from rest_framework.response import Response  from rest_framework.renderers import JSONRenderer @@ -17,6 +16,8 @@ from rest_framework.decorators import (      permission_classes,  ) +from rest_framework.tests.utils import RequestFactory +  class DecoratorTestCase(TestCase): @@ -27,13 +28,27 @@ class DecoratorTestCase(TestCase):          response.request = request          return APIView.finalize_response(self, request, response, *args, **kwargs) -    def test_wrap_view(self): +    def test_api_view_incorrect(self): +        """ +        If @api_view is not applied correct, we should raise an assertion. +        """ -        @api_view(['GET']) +        @api_view          def view(request): -            return Response({}) +            return Response() + +        request = self.factory.get('/') +        self.assertRaises(AssertionError, view, request) -        self.assertTrue(isinstance(view.cls_instance, APIView)) +    def test_api_view_incorrect_arguments(self): +        """ +        If @api_view is missing arguments, we should raise an assertion. +        """ + +        with self.assertRaises(AssertionError): +            @api_view('GET') +            def view(request): +                return Response()      def test_calling_method(self): @@ -63,6 +78,20 @@ class DecoratorTestCase(TestCase):          response = view(request)          self.assertEqual(response.status_code, 405) +    def test_calling_patch_method(self): + +        @api_view(['GET', 'PATCH']) +        def view(request): +            return Response({}) + +        request = self.factory.patch('/') +        response = view(request) +        self.assertEqual(response.status_code, 200) + +        request = self.factory.post('/') +        response = view(request) +        self.assertEqual(response.status_code, 405) +      def test_renderer_classes(self):          @api_view(['GET']) diff --git a/rest_framework/tests/extras/__init__.py b/rest_framework/tests/extras/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/extras/__init__.py diff --git a/rest_framework/tests/extras/bad_import.py b/rest_framework/tests/extras/bad_import.py new file mode 100644 index 00000000..68263d94 --- /dev/null +++ b/rest_framework/tests/extras/bad_import.py @@ -0,0 +1 @@ +raise ValueError diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py new file mode 100644 index 00000000..8068272d --- /dev/null +++ b/rest_framework/tests/fields.py @@ -0,0 +1,49 @@ +""" +General serializer field tests. +""" + +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class TimestampedModel(models.Model): +    added = models.DateTimeField(auto_now_add=True) +    updated = models.DateTimeField(auto_now=True) + + +class CharPrimaryKeyModel(models.Model): +    id = models.CharField(max_length=20, primary_key=True) + + +class TimestampedModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = TimestampedModel + + +class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = CharPrimaryKeyModel + + +class ReadOnlyFieldTests(TestCase): +    def test_auto_now_fields_read_only(self): +        """ +        auto_now and auto_now_add fields should be read_only by default. +        """ +        serializer = TimestampedModelSerializer() +        self.assertEquals(serializer.fields['added'].read_only, True) + +    def test_auto_pk_fields_read_only(self): +        """ +        AutoField fields should be read_only by default. +        """ +        serializer = TimestampedModelSerializer() +        self.assertEquals(serializer.fields['id'].read_only, True) + +    def test_non_auto_pk_fields_not_read_only(self): +        """ +        PK fields other than AutoField fields should not be read_only by default. +        """ +        serializer = CharPrimaryKeyModelSerializer() +        self.assertEquals(serializer.fields['id'].read_only, False) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py index ca6bc905..0434f900 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/files.py @@ -26,7 +26,6 @@ class UploadedFileSerializer(serializers.Serializer):  class FileSerializerTests(TestCase): -      def test_create(self):          now = datetime.datetime.now()          file = BytesIO(six.b('stuff')) @@ -38,3 +37,16 @@ class FileSerializerTests(TestCase):          self.assertEquals(serializer.object.created, uploaded_file.created)          self.assertEquals(serializer.object.file, uploaded_file.file)          self.assertFalse(serializer.object is uploaded_file) + +    def test_creation_failure(self): +        """ +        Passing files=None should result in an ValidationError + +        Regression test for: +        https://github.com/tomchristie/django-rest-framework/issues/542 +        """ +        now = datetime.datetime.now() + +        serializer = UploadedFileSerializer(data={'created': now}) +        self.assertFalse(serializer.is_valid()) +        self.assertIn('file', serializer.errors) diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py index ba29dbed..72070a1a 100644 --- a/rest_framework/tests/genericrelations.py +++ b/rest_framework/tests/genericrelations.py @@ -1,27 +1,63 @@  from __future__ import unicode_literals +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey +from django.db import models  from django.test import TestCase  from rest_framework import serializers -from rest_framework.tests.models import * + + +class Tag(models.Model): +    """ +    Tags have a descriptive slug, and are attached to an arbitrary object. +    """ +    tag = models.SlugField() +    content_type = models.ForeignKey(ContentType) +    object_id = models.PositiveIntegerField() +    tagged_item = GenericForeignKey('content_type', 'object_id') + +    def __unicode__(self): +        return self.tag + + +class Bookmark(models.Model): +    """ +    A URL bookmark that may have multiple tags attached. +    """ +    url = models.URLField() +    tags = GenericRelation(Tag) + +    def __unicode__(self): +        return 'Bookmark: %s' % self.url + + +class Note(models.Model): +    """ +    A textual note that may have multiple tags attached. +    """ +    text = models.TextField() +    tags = GenericRelation(Tag) + +    def __unicode__(self): +        return 'Note: %s' % self.text  class TestGenericRelations(TestCase):      def setUp(self): -        bookmark = Bookmark(url='https://www.djangoproject.com/') -        bookmark.save() -        django = Tag(tag_name='django') -        django.save() -        python = Tag(tag_name='python') -        python.save() -        t1 = TaggedItem(content_object=bookmark, tag=django) -        t1.save() -        t2 = TaggedItem(content_object=bookmark, tag=python) -        t2.save() -        self.bookmark = bookmark - -    def test_reverse_generic_relation(self): +        self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') +        Tag.objects.create(tagged_item=self.bookmark, tag='django') +        Tag.objects.create(tagged_item=self.bookmark, tag='python') +        self.note = Note.objects.create(text='Remember the milk') +        Tag.objects.create(tagged_item=self.note, tag='reminder') + +    def test_generic_relation(self): +        """ +        Test a relationship that spans a GenericRelation field. +        IE. A reverse generic relationship. +        """ +          class BookmarkSerializer(serializers.ModelSerializer): -            tags = serializers.ManyRelatedField(source='tags') +            tags = serializers.ManyRelatedField()              class Meta:                  model = Bookmark @@ -33,3 +69,33 @@ class TestGenericRelations(TestCase):              'url': 'https://www.djangoproject.com/'          }          self.assertEquals(serializer.data, expected) + +    def test_generic_fk(self): +        """ +        Test a relationship that spans a GenericForeignKey field. +        IE. A forward generic relationship. +        """ + +        class TagSerializer(serializers.ModelSerializer): +            tagged_item = serializers.RelatedField() + +            class Meta: +                model = Tag +                exclude = ('id', 'content_type', 'object_id') + +        serializer = TagSerializer(Tag.objects.all()) +        expected = [ +        { +            'tag': u'django', +            'tagged_item': u'Bookmark: https://www.djangoproject.com/' +        }, +        { +            'tag': u'python', +            'tagged_item': u'Bookmark: https://www.djangoproject.com/' +        }, +        { +            'tag': u'reminder', +            'tagged_item': u'Note: Remember the milk' +        } +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 215de0c4..fd01312a 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,12 +1,11 @@  from __future__ import unicode_literals -  from django.db import models  from django.test import TestCase -from django.test.client import RequestFactory -from django.utils import simplejson as json  from rest_framework import generics, serializers, status +from rest_framework.tests.utils import RequestFactory  from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel  from rest_framework.compat import six +import json  factory = RequestFactory() @@ -183,6 +182,20 @@ class TestInstanceView(TestCase):          updated = self.objects.get(id=1)          self.assertEquals(updated.text, 'foobar') +    def test_patch_instance_view(self): +        """ +        PATCH requests to RetrieveUpdateDestroyAPIView should update an object. +        """ +        content = {'text': 'foobar'} +        request = factory.patch('/1', json.dumps(content), +                              content_type='application/json') + +        response = self.view(request, pk=1).render() +        self.assertEquals(response.status_code, status.HTTP_200_OK) +        self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) +        updated = self.objects.get(id=1) +        self.assertEquals(updated.text, 'foobar') +      def test_delete_instance_view(self):          """          DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index ee4d8e57..c6a8224b 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -1,6 +1,6 @@ +import json  from django.test import TestCase  from django.test.client import RequestFactory -from django.utils import simplejson as json  from rest_framework import generics, status, serializers  from rest_framework.compat import patterns, url  from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 0759650a..9ab15328 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -71,6 +71,7 @@ class SlugBasedModel(RESTFrameworkModel):  class DefaultValueModel(RESTFrameworkModel):      text = models.CharField(default='foobar', max_length=100) +    extra = models.CharField(blank=True, null=True, max_length=100)  class CallableDefaultValueModel(RESTFrameworkModel): @@ -85,27 +86,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):      text = models.CharField(max_length=100, default='anchor')      rel = models.ManyToManyField(Anchor) -# Models to test generic relations - - -class Tag(RESTFrameworkModel): -    tag_name = models.SlugField() - - -class TaggedItem(RESTFrameworkModel): -    tag = models.ForeignKey(Tag, related_name='items') -    content_type = models.ForeignKey(ContentType) -    object_id = models.PositiveIntegerField() -    content_object = GenericForeignKey('content_type', 'object_id') - -    def __unicode__(self): -        return self.tag.tag_name - - -class Bookmark(RESTFrameworkModel): -    url = models.URLField() -    tags = GenericRelation(TaggedItem) -  # Model to test filtering.  class FilterableItem(RESTFrameworkModel): @@ -176,3 +156,42 @@ class OptionalRelationModel(RESTFrameworkModel):  # Model for RegexField  class Book(RESTFrameworkModel):      isbn = models.CharField(max_length=13) + + +# Models for relations tests +# ManyToMany +class ManyToManyTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class ManyToManySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') + + +# ForeignKey +class ForeignKeyTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class ForeignKeySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + + +# Nullable ForeignKey +class NullableForeignKeySource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, +                               related_name='nullable_sources') + + +# OneToOne +class OneToOneTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class NullableOneToOneSource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.OneToOneField(OneToOneTarget, null=True, blank=True, +                                  related_name='nullable_source') diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 81d297a1..697dfb5b 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -181,10 +181,10 @@ class UnitTestPagination(TestCase):          """          Ensure context gets passed through to the object serializer.          """ -        serializer = PassOnContextPaginationSerializer(self.first_page) +        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})          serializer.data          results = serializer.fields[serializer.results_field] -        self.assertTrue(serializer.context is results.context) +        self.assertEquals(serializer.context, results.context)  class TestUnpaginated(TestCase): @@ -252,6 +252,8 @@ class TestCustomPaginateByParam(TestCase):          self.assertEquals(response.data['results'], self.data[:5]) +### Tests for context in pagination serializers +  class CustomField(serializers.Field):      def to_native(self, value):          if not 'view' in self.context: @@ -262,6 +264,11 @@ class CustomField(serializers.Field):  class BasicModelSerializer(serializers.Serializer):      text = CustomField() +    def __init__(self, *args, **kwargs): +        super(BasicModelSerializer, self).__init__(*args, **kwargs) +        if not 'view' in self.context: +            raise RuntimeError("context isn't getting passed into serializer init") +  class TestContextPassedToCustomField(TestCase):      def setUp(self): @@ -279,3 +286,39 @@ class TestContextPassedToCustomField(TestCase):          self.assertEquals(response.status_code, status.HTTP_200_OK) + +### Tests for custom pagination serializers + +class LinksSerializer(serializers.Serializer): +    next = pagination.NextPageField(source='*') +    prev = pagination.PreviousPageField(source='*') + + +class CustomPaginationSerializer(pagination.BasePaginationSerializer): +    links = LinksSerializer(source='*')  # Takes the page object as the source +    total_results = serializers.Field(source='paginator.count') + +    results_field = 'objects' + + +class TestCustomPaginationSerializer(TestCase): +    def setUp(self): +        objects = ['john', 'paul', 'george', 'ringo'] +        paginator = Paginator(objects, 2) +        self.page = paginator.page(1) + +    def test_custom_pagination_serializer(self): +        request = RequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=self.page, +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': 'http://testserver/foobar?page=2', +                'prev': None +            }, +            'total_results': 4, +            'objects': ['john', 'paul'] +        } +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations.py b/rest_framework/tests/relations.py new file mode 100644 index 00000000..edc85f9e --- /dev/null +++ b/rest_framework/tests/relations.py @@ -0,0 +1,47 @@ +""" +General tests for relational fields. +""" + +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class NullModel(models.Model): +    pass + + +class FieldTests(TestCase): +    def test_pk_related_field_with_empty_string(self): +        """ +        Regression test for #446 + +        https://github.com/tomchristie/django-rest-framework/issues/446 +        """ +        field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) + +    def test_hyperlinked_related_field_with_empty_string(self): +        field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) + +    def test_slug_related_field_with_empty_string(self): +        field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') +        self.assertRaises(serializers.ValidationError, field.from_native, '') +        self.assertRaises(serializers.ValidationError, field.from_native, []) + + +class TestManyRelateMixin(TestCase): +    def test_missing_many_to_many_related_field(self): +        ''' +        Regression test for #632 + +        https://github.com/tomchristie/django-rest-framework/pull/632 +        ''' +        field = serializers.ManyRelatedField(read_only=False) + +        into = {} +        field.field_from_native({}, None, 'field_name', into) +        self.assertEqual(into['field_name'], []) diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index 407c04e0..b4ad3166 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -1,9 +1,9 @@  from __future__ import unicode_literals -from django.db import models  from django.test import TestCase  from rest_framework import serializers  from rest_framework.compat import patterns, url +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource  def dummy_view(request, pk): @@ -15,20 +15,11 @@ urlpatterns = patterns('',      url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),      url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),      url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), +    url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), +    url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),  ) -# ManyToMany - -class ManyToManyTarget(models.Model): -    name = models.CharField(max_length=100) - - -class ManyToManySource(models.Model): -    name = models.CharField(max_length=100) -    targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') - -  class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):      sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') @@ -41,17 +32,6 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):          model = ManyToManySource -# ForeignKey - -class ForeignKeyTarget(models.Model): -    name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, related_name='sources') - -  class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):      sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail') @@ -65,16 +45,17 @@ class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):  # Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = NullableForeignKeySource -class NullableForeignKeySource(models.Model): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, -                               related_name='nullable_sources') +# OneToOne +class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): +    nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') -class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):      class Meta: -        model = NullableForeignKeySource +        model = OneToOneTarget  # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -236,6 +217,13 @@ class HyperlinkedForeignKeyTests(TestCase):          ]          self.assertEquals(serializer.data, expected) +    def test_foreign_key_update_incorrect_type(self): +        data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': 2} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Incorrect type.  Expected url string, received int.']}) +      def test_reverse_foreign_key_update(self):          data = {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}          instance = ForeignKeyTarget.objects.get(pk=2) @@ -248,7 +236,7 @@ class HyperlinkedForeignKeyTests(TestCase):          expected = [              {'url': '/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},              {'url': '/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, -        ]         +        ]          self.assertEquals(new_serializer.data, expected)          serializer.save() @@ -434,3 +422,24 @@ class HyperlinkedNullableForeignKeyTests(TestCase):      #         {'id': 2, 'name': 'target-2', 'sources': []},      #     ]      #     self.assertEquals(serializer.data, expected) + + +class HyperlinkedNullableOneToOneTests(TestCase): +    urls = 'rest_framework.tests.relations_hyperlink' + +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset) +        expected = [ +            {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'}, +            {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None}, +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 442cbebe..e81f0e42 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -1,19 +1,7 @@  from __future__ import unicode_literals - -from django.db import models  from django.test import TestCase  from rest_framework import serializers - - -# ForeignKey - -class ForeignKeyTarget(models.Model): -    name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, related_name='sources') +from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource  class ForeignKeySourceSerializer(serializers.ModelSerializer): @@ -34,20 +22,24 @@ class ForeignKeyTargetSerializer(serializers.ModelSerializer):          model = ForeignKeyTarget -# Nullable ForeignKey - -class NullableForeignKeySource(models.Model): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, -                               related_name='nullable_sources') - -  class NullableForeignKeySourceSerializer(serializers.ModelSerializer):      class Meta:          depth = 1          model = NullableForeignKeySource +class NullableOneToOneSourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableOneToOneSource + + +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): +    nullable_source = NullableOneToOneSourceSerializer() + +    class Meta: +        model = OneToOneTarget + +  class ReverseForeignKeyTests(TestCase):      def setUp(self):          target = ForeignKeyTarget(name='target-1') @@ -102,3 +94,22 @@ class NestedNullableForeignKeyTests(TestCase):              {'id': 3, 'name': 'source-3', 'target': None},          ]          self.assertEquals(serializer.data, expected) + + +class NestedNullableOneToOneTests(TestCase): +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}}, +            {'id': 2, 'name': u'target-2', 'nullable_source': None}, +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index a04c5c80..4d00795a 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -3,17 +3,7 @@ from __future__ import unicode_literals  from django.db import models  from django.test import TestCase  from rest_framework import serializers - - -# ManyToMany - -class ManyToManyTarget(models.Model): -    name = models.CharField(max_length=100) - - -class ManyToManySource(models.Model): -    name = models.CharField(max_length=100) -    targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource  class ManyToManyTargetSerializer(serializers.ModelSerializer): @@ -28,17 +18,6 @@ class ManyToManySourceSerializer(serializers.ModelSerializer):          model = ManyToManySource -# ForeignKey - -class ForeignKeyTarget(models.Model): -    name = models.CharField(max_length=100) - - -class ForeignKeySource(models.Model): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, related_name='sources') - -  class ForeignKeyTargetSerializer(serializers.ModelSerializer):      sources = serializers.ManyPrimaryKeyRelatedField() @@ -51,17 +30,17 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer):          model = ForeignKeySource -# Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableForeignKeySource -class NullableForeignKeySource(models.Model): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, -                               related_name='nullable_sources') +# OneToOne +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): +    nullable_source = serializers.PrimaryKeyRelatedField() -class NullableForeignKeySourceSerializer(serializers.ModelSerializer):      class Meta: -        model = NullableForeignKeySource +        model = OneToOneTarget  # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -218,6 +197,13 @@ class PKForeignKeyTests(TestCase):          ]          self.assertEquals(serializer.data, expected) +    def test_foreign_key_update_incorrect_type(self): +        data = {'id': 1, 'name': u'source-1', 'target': 'foo'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Incorrect type.  Expected pk value, received str.']}) +      def test_reverse_foreign_key_update(self):          data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}          instance = ForeignKeyTarget.objects.get(pk=2) @@ -230,7 +216,7 @@ class PKForeignKeyTests(TestCase):          expected = [              {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},              {'id': 2, 'name': 'target-2', 'sources': []}, -        ]         +        ]          self.assertEquals(new_serializer.data, expected)          serializer.save() @@ -414,3 +400,22 @@ class PKNullableForeignKeyTests(TestCase):      #         {'id': 2, 'name': 'target-2', 'sources': []},      #     ]      #     self.assertEquals(serializer.data, expected) + + +class PKNullableOneToOneTests(TestCase): +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'nullable_source': 1}, +            {'id': 2, 'name': u'target-2', 'nullable_source': None}, +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py new file mode 100644 index 00000000..37ccc75e --- /dev/null +++ b/rest_framework/tests/relations_slug.py @@ -0,0 +1,257 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): +    sources = serializers.ManySlugRelatedField(slug_field='name') + +    class Meta: +        model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    target = serializers.SlugRelatedField(slug_field='name') + +    class Meta: +        model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    target = serializers.SlugRelatedField(slug_field='name', null=True) + +    class Meta: +        model = NullableForeignKeySource + + +# TODO: M2M Tests, FKTests (Non-nulable), One2One +class PKForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': 'target-1'} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, +            {'id': 2, 'name': u'target-2', 'sources': []}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update(self): +        data = {'id': 1, 'name': u'source-1', 'target': 'target-2'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-2'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': 'target-1'} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_incorrect_type(self): +        data = {'id': 1, 'name': u'source-1', 'target': 123} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Object with name=123 does not exist.']}) + +    def test_reverse_foreign_key_update(self): +        data = {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        # We shouldn't have saved anything to the db yet since save +        # hasn't been called. +        queryset = ForeignKeyTarget.objects.all() +        new_serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, +            {'id': 2, 'name': u'target-2', 'sources': []}, +        ] +        self.assertEquals(new_serializer.data, expected) + +        serializer.save() +        self.assertEquals(serializer.data, data) + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': ['source-2']}, +            {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'id': 4, 'name': u'source-4', 'target': 'target-2'} +        serializer = ForeignKeySourceSerializer(data=data) +        serializer.is_valid() +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': 'target-1'}, +            {'id': 4, 'name': u'source-4', 'target': 'target-2'}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']} +        serializer = ForeignKeyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'target-3') + +        # Ensure target 3 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': ['source-2']}, +            {'id': 2, 'name': u'target-2', 'sources': []}, +            {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'id': 1, 'name': u'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + + +class SlugNullableForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'id': 4, 'name': u'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None}, +            {'id': 4, 'name': u'source-4', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 4, 'name': u'source-4', 'target': ''} +        expected_data = {'id': 4, 'name': u'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, expected_data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None}, +            {'id': 4, 'name': u'source-4', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'id': 1, 'name': u'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': None}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 1, 'name': u'source-1', 'target': ''} +        expected_data = {'id': 1, 'name': u'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, expected_data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': None}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py index 7d4575bb..92b1bfd8 100644 --- a/rest_framework/tests/request.py +++ b/rest_framework/tests/request.py @@ -1,12 +1,12 @@  """  Tests for content parsing, and form-overloaded content parsing.  """ +import json  from django.contrib.auth.models import User  from django.contrib.auth import authenticate, login, logout  from django.contrib.sessions.middleware import SessionMiddleware  from django.test import TestCase, Client  from django.test.client import RequestFactory -from django.utils import simplejson as json  from rest_framework import status  from rest_framework.authentication import SessionAuthentication  from rest_framework.compat import patterns diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 6ce7de31..a00626b5 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -56,6 +56,19 @@ class ActionItemSerializer(serializers.ModelSerializer):          model = ActionItem +class ActionItemSerializerCustomRestore(serializers.ModelSerializer): + +    class Meta: +        model = ActionItem + +    def restore_object(self, data, instance=None): +        if instance is None: +            return ActionItem(**data) +        for key, val in data.items(): +            setattr(instance, key, val) +        return instance + +  class PersonSerializer(serializers.ModelSerializer):      info = serializers.Field(source='info') @@ -71,6 +84,7 @@ class AlbumsSerializer(serializers.ModelSerializer):          model = Album          fields = ['title']  # lists are also valid options +  class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):      class Meta:          model = HasPositiveIntegerAsChoice @@ -163,7 +177,6 @@ class BasicTests(TestCase):          """          Attempting to update fields set as read_only should have no effect.          """ -          serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})          self.assertEquals(serializer.is_valid(), True)          instance = serializer.save() @@ -184,8 +197,7 @@ class ValidationTests(TestCase):              'content': 'x' * 1001,              'created': datetime.datetime(2012, 1, 1)          } -        self.actionitem = ActionItem(title='Some to do item', -        ) +        self.actionitem = ActionItem(title='Some to do item',)      def test_create(self):          serializer = CommentSerializer(data=self.data) @@ -217,30 +229,24 @@ class ValidationTests(TestCase):          self.assertEquals(serializer.is_valid(), True)          self.assertEquals(serializer.errors, {}) -    def test_field_validation(self): - -        class CommentSerializerWithFieldValidator(CommentSerializer): - -            def validate_content(self, attrs, source): -                value = attrs[source] -                if "test" not in value: -                    raise serializers.ValidationError("Test not in value") -                return attrs - -        data = { -            'email': 'tom@example.com', -            'content': 'A test comment', -            'created': datetime.datetime(2012, 1, 1) -        } - -        serializer = CommentSerializerWithFieldValidator(data=data) -        self.assertTrue(serializer.is_valid()) +    def test_bad_type_data_is_false(self): +        """ +        Data of the wrong type is not valid. +        """ +        data = ['i am', 'a', 'list'] +        serializer = CommentSerializer(self.comment, data=data) +        self.assertEquals(serializer.is_valid(), False) +        self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) -        data['content'] = 'This should not validate' +        data = 'and i am a string' +        serializer = CommentSerializer(self.comment, data=data) +        self.assertEquals(serializer.is_valid(), False) +        self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) -        serializer = CommentSerializerWithFieldValidator(data=data) -        self.assertFalse(serializer.is_valid()) -        self.assertEquals(serializer.errors, {'content': ['Test not in value']}) +        data = 42 +        serializer = CommentSerializer(self.comment, data=data) +        self.assertEquals(serializer.is_valid(), False) +        self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']})      def test_cross_field_validation(self): @@ -282,6 +288,20 @@ class ValidationTests(TestCase):          self.assertEquals(serializer.is_valid(), False)          self.assertEquals(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']}) +    def test_modelserializer_max_length_exceeded_with_custom_restore(self): +        """ +        When overriding ModelSerializer.restore_object, validation tests should still apply. +        Regression test for #623. + +        https://github.com/tomchristie/django-rest-framework/pull/623 +        """ +        data = { +            'title': 'x' * 201, +        } +        serializer = ActionItemSerializerCustomRestore(data=data) +        self.assertEquals(serializer.is_valid(), False) +        self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']}) +      def test_default_modelfield_max_length_exceeded(self):          data = {              'title': 'Testing "info" field...', @@ -292,12 +312,69 @@ class ValidationTests(TestCase):          self.assertEquals(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']}) +class CustomValidationTests(TestCase): +    class CommentSerializerWithFieldValidator(CommentSerializer): + +        def validate_email(self, attrs, source): +            value = attrs[source] + +            return attrs + +        def validate_content(self, attrs, source): +            value = attrs[source] +            if "test" not in value: +                raise serializers.ValidationError("Test not in value") +            return attrs + +    def test_field_validation(self): +        data = { +            'email': 'tom@example.com', +            'content': 'A test comment', +            'created': datetime.datetime(2012, 1, 1) +        } + +        serializer = self.CommentSerializerWithFieldValidator(data=data) +        self.assertTrue(serializer.is_valid()) + +        data['content'] = 'This should not validate' + +        serializer = self.CommentSerializerWithFieldValidator(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) + +    def test_missing_data(self): +        """ +        Make sure that validate_content isn't called if the field is missing +        """ +        incomplete_data = { +            'email': 'tom@example.com', +            'created': datetime.datetime(2012, 1, 1) +        } +        serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'content': [u'This field is required.']}) + +    def test_wrong_data(self): +        """ +        Make sure that validate_content isn't called if the field input is wrong +        """ +        wrong_data = { +            'email': 'not an email', +            'content': 'A test comment', +            'created': datetime.datetime(2012, 1, 1) +        } +        serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'email': [u'Enter a valid e-mail address.']}) + +  class PositiveIntegerAsChoiceTests(TestCase):      def test_positive_integer_in_json_is_correctly_parsed(self): -        data = {'some_integer':1} +        data = {'some_integer': 1}          serializer = PositiveIntegerAsChoiceSerializer(data=data)          self.assertEquals(serializer.is_valid(), True) +  class ModelValidationTests(TestCase):      def test_validate_unique(self):          """ @@ -342,7 +419,6 @@ class ModelValidationTests(TestCase):          self.assertTrue(photo_serializer.save()) -  class RegexValidationTest(TestCase):      def test_create_failed(self):          serializer = BookSerializer(data={'isbn': '1234567890'}) @@ -553,6 +629,21 @@ class DefaultValueTests(TestCase):          self.assertEquals(instance.pk, 1)          self.assertEquals(instance.text, 'overridden') +    def test_partial_update_default(self): +        """ Regression test for issue #532 """ +        data = {'text': 'overridden'} +        serializer = self.serializer_class(data=data, partial=True) +        self.assertEquals(serializer.is_valid(), True) +        instance = serializer.save() + +        data = {'extra': 'extra_value'} +        serializer = self.serializer_class(instance=instance, data=data, partial=True) +        self.assertEquals(serializer.is_valid(), True) +        instance = serializer.save() + +        self.assertEquals(instance.extra, 'extra_value') +        self.assertEquals(instance.text, 'overridden') +  class CallableDefaultValueTests(TestCase):      def setUp(self): diff --git a/rest_framework/tests/settings.py b/rest_framework/tests/settings.py new file mode 100644 index 00000000..0293fdc3 --- /dev/null +++ b/rest_framework/tests/settings.py @@ -0,0 +1,21 @@ +"""Tests for the settings module""" +from django.test import TestCase + +from rest_framework.settings import APISettings, DEFAULTS, IMPORT_STRINGS + + +class TestSettings(TestCase): +    """Tests relating to the api settings""" + +    def test_non_import_errors(self): +        """Make sure other errors aren't suppressed.""" +        settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.bad_import.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) +        with self.assertRaises(ValueError): +            settings.DEFAULT_MODEL_SERIALIZER_CLASS + +    def test_import_error_message_maintained(self): +        """Make sure real import errors are captured and raised sensibly.""" +        settings = APISettings({'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.tests.extras.not_here.ModelSerializer'}, DEFAULTS, IMPORT_STRINGS) +        with self.assertRaises(ImportError) as cm: +            settings.DEFAULT_MODEL_SERIALIZER_CLASS +        self.assertTrue('ImportError' in str(cm.exception)) diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py new file mode 100644 index 00000000..43e8ef69 --- /dev/null +++ b/rest_framework/tests/urlpatterns.py @@ -0,0 +1,78 @@ +from collections import namedtuple + +from django.core import urlresolvers + +from django.test import TestCase +from django.test.client import RequestFactory + +from rest_framework.compat import patterns, url, include +from rest_framework.urlpatterns import format_suffix_patterns + + +# A container class for test paths for the test case +URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) + + +def dummy_view(request, *args, **kwargs): +    pass + + +class FormatSuffixTests(TestCase): +    """ +    Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. +    """ +    def _resolve_urlpatterns(self, urlpatterns, test_paths): +        factory = RequestFactory() +        try: +            urlpatterns = format_suffix_patterns(urlpatterns) +        except: +            self.fail("Failed to apply `format_suffix_patterns` on  the supplied urlpatterns") +        resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) +        for test_path in test_paths: +            request = factory.get(test_path.path) +            try: +                callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) +            except: +                self.fail("Failed to resolve URL: %s" % request.path_info) +            self.assertEquals(callback_args, test_path.args) +            self.assertEquals(callback_kwargs, test_path.kwargs) + +    def test_format_suffix(self): +        urlpatterns = patterns( +            '', +            url(r'^test$', dummy_view), +        ) +        test_paths = [ +            URLTestPath('/test', (), {}), +            URLTestPath('/test.api', (), {'format': 'api'}), +            URLTestPath('/test.asdf', (), {'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) + +    def test_default_args(self): +        urlpatterns = patterns( +            '', +            url(r'^test$', dummy_view, {'foo': 'bar'}), +        ) +        test_paths = [ +            URLTestPath('/test', (), {'foo': 'bar', }), +            URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), +            URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) + +    def test_included_urls(self): +        nested_patterns = patterns( +            '', +            url(r'^path$', dummy_view) +        ) +        urlpatterns = patterns( +            '', +            url(r'^test/', include(nested_patterns), {'foo': 'bar'}), +        ) +        test_paths = [ +            URLTestPath('/test/path', (), {'foo': 'bar', }), +            URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), +            URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py new file mode 100644 index 00000000..3906adb9 --- /dev/null +++ b/rest_framework/tests/utils.py @@ -0,0 +1,27 @@ +from django.test.client import RequestFactory, FakePayload +from django.test.client import MULTIPART_CONTENT +from urlparse import urlparse + + +class RequestFactory(RequestFactory): + +    def __init__(self, **defaults): +        super(RequestFactory, self).__init__(**defaults) + +    def patch(self, path, data={}, content_type=MULTIPART_CONTENT, +            **extra): +        "Construct a PATCH request." + +        patch_data = self._encode_data(data, content_type) + +        parsed = urlparse(path) +        r = { +            'CONTENT_LENGTH': len(patch_data), +            'CONTENT_TYPE':   content_type, +            'PATH_INFO':      self._get_path(parsed), +            'QUERY_STRING':   parsed[4], +            'REQUEST_METHOD': 'PATCH', +            'wsgi.input':     FakePayload(patch_data), +        } +        r.update(extra) +        return self.request(**r) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py index e51ca9f3..f2432516 100644 --- a/rest_framework/tests/views.py +++ b/rest_framework/tests/views.py @@ -20,7 +20,7 @@ class BasicView(APIView):          return Response({'method': 'POST', 'data': request.DATA}) -@api_view(['GET', 'POST', 'PUT']) +@api_view(['GET', 'POST', 'PUT', 'PATCH'])  def basic_view(request):      if request.method == 'GET':          return {'method': 'GET'} @@ -28,6 +28,8 @@ def basic_view(request):          return {'method': 'POST', 'data': request.DATA}      elif request.method == 'PUT':          return {'method': 'PUT', 'data': request.DATA} +    elif request.method == 'PATCH': +        return {'method': 'PATCH', 'data': request.DATA}  def sanitise_json_error(error_dict): diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 143928c9..47789026 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,5 +1,35 @@ -from rest_framework.compat import url +from rest_framework.compat import url, include  from rest_framework.settings import api_settings +from django.core.urlresolvers import RegexURLResolver + + +def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required): +    ret = [] +    for urlpattern in urlpatterns: +        if isinstance(urlpattern, RegexURLResolver): +            # Set of included URL patterns +            regex = urlpattern.regex.pattern +            namespace = urlpattern.namespace +            app_name = urlpattern.app_name +            kwargs = urlpattern.default_kwargs +            # Add in the included patterns, after applying the suffixes +            patterns = apply_suffix_patterns(urlpattern.url_patterns, +                                             suffix_pattern, +                                             suffix_required) +            ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) + +        else: +            # Regular URL pattern +            regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern +            view = urlpattern._callback or urlpattern._callback_str +            kwargs = urlpattern.default_args +            name = urlpattern.name +            # Add in both the existing and the new urlpattern +            if not suffix_required: +                ret.append(urlpattern) +            ret.append(url(regex, view, kwargs, name)) + +    return ret  def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): @@ -28,15 +58,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):      else:          suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg -    ret = [] -    for urlpattern in urlpatterns: -        # Form our complementing '.format' urlpattern -        regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern -        view = urlpattern._callback or urlpattern._callback_str -        kwargs = urlpattern.default_args -        name = urlpattern.name -        # Add in both the existing and the new urlpattern -        if not suffix_required: -            ret.append(urlpattern) -        ret.append(url(regex, view, kwargs, name)) -    return ret +    return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index 2d1fb353..7afe100a 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -4,7 +4,7 @@ Helper classes for parsers.  import datetime  import decimal  import types -from django.utils import simplejson as json +import json  from django.utils.datastructures import SortedDict  from rest_framework.compat import timezone  from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata @@ -12,7 +12,7 @@ from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata  class JSONEncoder(json.JSONEncoder):      """ -    JSONEncoder subclass that knows how to encode date/time, +    JSONEncoder subclass that knows how to encode date/time/timedelta,      decimal types, and generators.      """      def default(self, o): @@ -34,6 +34,8 @@ class JSONEncoder(json.JSONEncoder):              if o.microsecond:                  r = r[:12]              return r +        elif isinstance(o, datetime.timedelta): +            return str(o.total_seconds())          elif isinstance(o, decimal.Decimal):              return str(o)          elif hasattr(o, '__iter__'): diff --git a/rest_framework/views.py b/rest_framework/views.py index 10bdd5a5..ac9b3385 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -148,6 +148,8 @@ class APIView(View):          """          If request is not permitted, determine what kind of exception to raise.          """ +        if not self.request.successful_authenticator: +            raise exceptions.NotAuthenticated()          raise exceptions.PermissionDenied()      def throttled(self, request, wait): @@ -156,6 +158,15 @@ class APIView(View):          """          raise exceptions.Throttled(wait) +    def get_authenticate_header(self, request): +        """ +        If a request is unauthenticated, determine the WWW-Authenticate +        header to use for 401 responses, if any. +        """ +        authenticators = self.get_authenticators() +        if authenticators: +            return authenticators[0].authenticate_header(request) +      def get_parser_context(self, http_request):          """          Returns a dict that is passed through to Parser.parse(), @@ -319,6 +330,16 @@ class APIView(View):              # Throttle wait header              self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait +        if isinstance(exc, (exceptions.NotAuthenticated, +                            exceptions.AuthenticationFailed)): +            # WWW-Authenticate header for 401 responses, else coerce to 403 +            auth_header = self.get_authenticate_header(self.request) + +            if auth_header: +                self.headers['WWW-Authenticate'] = auth_header +            else: +                exc.status_code = status.HTTP_403_FORBIDDEN +          if isinstance(exc, exceptions.APIException):              return Response({'detail': exc.detail},                              status=exc.status_code,  | 
